update recurrent genome

This commit is contained in:
root
2024-07-10 16:27:49 +08:00
parent 1d606eb1c3
commit 649d4b0552
8 changed files with 490 additions and 46 deletions

View File

@@ -78,31 +78,34 @@ class DefaultGenome(BaseGenome):
def cond_fun(carry):
values, idx = carry
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
return (idx < self.max_nodes) & (
cal_seqs[idx] != I_INF
) # not out of bounds and next node exists
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def input_node():
z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
new_values = values.at[i].set(z)
return new_values
return values
def otherwise():
# calculate connections
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values
)
# calculate nodes
z = self.node_gene.forward(
state,
nodes_attrs[i],
ins,
is_output_node=jnp.isin(i, self.output_idx),
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
)
# set new value
new_values = values.at[i].set(z)
return new_values