All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
|
||||
conn_enable = u_conns[0] == 1
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
|
||||
return nodes, u_conns
|
||||
return state, nodes, u_conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
nodes, conns = transformed
|
||||
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
|
||||
vals = jnp.full((N,), jnp.nan)
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def body_func(_, values):
|
||||
def body_func(_, carry):
|
||||
state_, values = carry
|
||||
|
||||
# set input values
|
||||
values = values.at[self.input_idx].set(inputs)
|
||||
|
||||
# calculate connections
|
||||
node_ins = jax.vmap(
|
||||
state_, node_ins = jax.vmap(
|
||||
jax.vmap(
|
||||
self.conn_gene.forward,
|
||||
in_axes=(None, 1, None)
|
||||
in_axes=(None, 1, None),
|
||||
out_axes=(None, 0)
|
||||
),
|
||||
in_axes=(None, 1, 0)
|
||||
)(state, conns, values)
|
||||
in_axes=(None, 1, 0),
|
||||
out_axes=(None, 0)
|
||||
)(state_, conns, values)
|
||||
|
||||
# calculate nodes
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
state_, values = jax.vmap(
|
||||
self.node_gene.forward,
|
||||
in_axes=(None, 0, 0, 0),
|
||||
out_axes=(None, 0)
|
||||
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
return state_, values
|
||||
|
||||
return vals[self.output_idx]
|
||||
state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
|
||||
|
||||
return state, vals[self.output_idx]
|
||||
|
||||
Reference in New Issue
Block a user