make fully stateful in module genome.
This commit is contained in:
@@ -32,7 +32,7 @@ class RecurrentGenome(BaseGenome):
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
def transform(self, state, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
# remove un-enable connections and remove enable attr
|
||||
@@ -41,7 +41,7 @@ class RecurrentGenome(BaseGenome):
|
||||
|
||||
return nodes, u_conns
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
def forward(self, state, inputs, transformed):
|
||||
nodes, conns = transformed
|
||||
|
||||
N = nodes.shape[0]
|
||||
@@ -56,17 +56,17 @@ class RecurrentGenome(BaseGenome):
|
||||
node_ins = jax.vmap(
|
||||
jax.vmap(
|
||||
self.conn_gene.forward,
|
||||
in_axes=(1, None)
|
||||
in_axes=(None, 1, None)
|
||||
),
|
||||
in_axes=(1, 0)
|
||||
)(conns, values)
|
||||
in_axes=(None, 1, 0)
|
||||
)(state, conns, values)
|
||||
|
||||
# calculate nodes
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
|
||||
Reference in New Issue
Block a user