From 485d4817457832eba4cea1b2a5c83286caeff4a5 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 25 May 2024 16:19:06 +0800 Subject: [PATCH] make fully stateful in module genome. --- tensorneat/algorithm/neat/genome/base.py | 17 ++++++++++------- tensorneat/algorithm/neat/genome/default.py | 8 ++++---- tensorneat/algorithm/neat/genome/recurrent.py | 12 ++++++------ 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index 928c0f5..b704f45 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -1,6 +1,6 @@ import jax.numpy as jnp from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene -from utils import fetch_first +from utils import fetch_first, State class BaseGenome: @@ -24,13 +24,16 @@ class BaseGenome: self.node_gene = node_gene self.conn_gene = conn_gene - def transform(self, nodes, conns): + def setup(self, state=State()): + return state + + def transform(self, state, nodes, conns): raise NotImplementedError - def forward(self, inputs, transformed): + def forward(self, state, inputs, transformed): raise NotImplementedError - def add_node(self, nodes, new_key: int, attrs): + def add_node(self, state, nodes, new_key: int, attrs): """ Add a new node to the genome. The new node will place at the first NaN row. @@ -40,14 +43,14 @@ class BaseGenome: new_nodes = nodes.at[pos, 0].set(new_key) return new_nodes.at[pos, 1:].set(attrs) - def delete_node_by_pos(self, nodes, pos): + def delete_node_by_pos(self, state, nodes, pos): """ Delete a node from the genome. Delete the node by its pos in nodes. """ return nodes.at[pos].set(jnp.nan) - def add_conn(self, conns, i_key, o_key, enable: bool, attrs): + def add_conn(self, state, conns, i_key, o_key, enable: bool, attrs): """ Add a new connection to the genome. The new connection will place at the first NaN row. @@ -57,7 +60,7 @@ class BaseGenome: new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable])) return new_conns.at[pos, 3:].set(attrs) - def delete_conn_by_pos(self, conns, pos): + def delete_conn_by_pos(self, state, conns, pos): """ Delete a connection from the genome. Delete the connection by its idx. diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 751b52c..0425938 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -30,7 +30,7 @@ class DefaultGenome(BaseGenome): raise ValueError(f"Output transform function failed: {e}") self.output_transform = output_transform - def transform(self, nodes, conns): + def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) conn_enable = u_conns[0] == 1 @@ -40,7 +40,7 @@ class DefaultGenome(BaseGenome): return seqs, nodes, u_conns - def forward(self, inputs, transformed): + def forward(self, state, inputs, transformed): cal_seqs, nodes, conns = transformed N = nodes.shape[0] @@ -57,8 +57,8 @@ class DefaultGenome(BaseGenome): i = cal_seqs[idx] def hit(): - ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) - z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) + ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values) + z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) new_values = values.at[i].set(z) return new_values diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 5ed4737..8469884 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -32,7 +32,7 @@ class RecurrentGenome(BaseGenome): raise ValueError(f"Output transform function failed: {e}") self.output_transform = output_transform - def transform(self, nodes, conns): + def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) # remove un-enable connections and remove enable attr @@ -41,7 +41,7 @@ class RecurrentGenome(BaseGenome): return nodes, u_conns - def forward(self, inputs, transformed): + def forward(self, state, inputs, transformed): nodes, conns = transformed N = nodes.shape[0] @@ -56,17 +56,17 @@ class RecurrentGenome(BaseGenome): node_ins = jax.vmap( jax.vmap( self.conn_gene.forward, - in_axes=(1, None) + in_axes=(None, 1, None) ), - in_axes=(1, 0) - )(conns, values) + in_axes=(None, 1, 0) + )(state, conns, values) # calculate nodes is_output_nodes = jnp.isin( jnp.arange(N), self.output_idx ) - values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes) + values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes) return values vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)