make fully stateful in module genome.

This commit is contained in:
wls2002
2024-05-25 16:19:06 +08:00
parent 625c261a49
commit 485d481745
3 changed files with 20 additions and 17 deletions

View File

@@ -32,7 +32,7 @@ class RecurrentGenome(BaseGenome):
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
def transform(self, state, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# remove un-enable connections and remove enable attr
@@ -41,7 +41,7 @@ class RecurrentGenome(BaseGenome):
return nodes, u_conns
def forward(self, inputs, transformed):
def forward(self, state, inputs, transformed):
nodes, conns = transformed
N = nodes.shape[0]
@@ -56,17 +56,17 @@ class RecurrentGenome(BaseGenome):
node_ins = jax.vmap(
jax.vmap(
self.conn_gene.forward,
in_axes=(1, None)
in_axes=(None, 1, None)
),
in_axes=(1, 0)
)(conns, values)
in_axes=(None, 1, 0)
)(state, conns, values)
# calculate nodes
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)