update recurrent genome
This commit is contained in:
@@ -78,31 +78,34 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
|
||||
return (idx < self.max_nodes) & (
|
||||
cal_seqs[idx] != I_INF
|
||||
) # not out of bounds and next node exists
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def input_node():
|
||||
z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
return values
|
||||
|
||||
def otherwise():
|
||||
# calculate connections
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
|
||||
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
state, hit_attrs, values
|
||||
)
|
||||
|
||||
# calculate nodes
|
||||
z = self.node_gene.forward(
|
||||
state,
|
||||
nodes_attrs[i],
|
||||
ins,
|
||||
is_output_node=jnp.isin(i, self.output_idx),
|
||||
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
|
||||
)
|
||||
|
||||
# set new value
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
|
||||
Reference in New Issue
Block a user