All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
conn_enable = u_conns[0] == 1
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
return state, nodes, u_conns
def forward(self, state, inputs, transformed):
nodes, conns = transformed
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
vals = jnp.full((N,), jnp.nan)
nodes_attrs = nodes[:, 1:]
def body_func(_, values):
def body_func(_, carry):
state_, values = carry
# set input values
values = values.at[self.input_idx].set(inputs)
# calculate connections
node_ins = jax.vmap(
state_, node_ins = jax.vmap(
jax.vmap(
self.conn_gene.forward,
in_axes=(None, 1, None)
in_axes=(None, 1, None),
out_axes=(None, 0)
),
in_axes=(None, 1, 0)
)(state, conns, values)
in_axes=(None, 1, 0),
out_axes=(None, 0)
)(state_, conns, values)
# calculate nodes
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
return values
state_, values = jax.vmap(
self.node_gene.forward,
in_axes=(None, 0, 0, 0),
out_axes=(None, 0)
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
return state_, values
return vals[self.output_idx]
state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
return state, vals[self.output_idx]