From 649d4b0552cc8c41f7a999a790cb787d20370043 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 10 Jul 2024 16:27:49 +0800 Subject: [PATCH] update recurrent genome --- examples/tmp.py | 15 +- examples/tmp2.py | 16 + network.svg | 415 ++++++++++++++++++ tensorneat/algorithm/neat/gene/base.py | 3 - tensorneat/algorithm/neat/genome/base.py | 6 +- tensorneat/algorithm/neat/genome/default.py | 15 +- tensorneat/algorithm/neat/genome/recurrent.py | 57 +-- tensorneat/common/tools.py | 9 +- 8 files changed, 490 insertions(+), 46 deletions(-) create mode 100644 examples/tmp2.py create mode 100644 network.svg diff --git a/examples/tmp.py b/examples/tmp.py index 0f7ddae..e900e5a 100644 --- a/examples/tmp.py +++ b/examples/tmp.py @@ -1,10 +1,21 @@ import jax, jax.numpy as jnp from tensorneat.algorithm import NEAT -from tensorneat.algorithm.neat import DefaultGenome +from tensorneat.algorithm.neat import DefaultGenome, RecurrentGenome key = jax.random.key(0) -genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=()) +genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=(1, 2 ,3)) state = genome.setup() nodes, conns = genome.initialize(state, key) print(genome.repr(state, nodes, conns)) + +inputs = jnp.array([1, 2, 3, 4, 5]) +transformed = genome.transform(state, nodes, conns) +outputs = genome.forward(state, transformed, inputs) + +print(outputs) + +network = genome.network_dict(state, nodes, conns) +print(network) + +genome.visualize(network) diff --git a/examples/tmp2.py b/examples/tmp2.py new file mode 100644 index 0000000..26d752e --- /dev/null +++ b/examples/tmp2.py @@ -0,0 +1,16 @@ +import jax, jax.numpy as jnp + +arr = jnp.ones((10, 10)) +a = jnp.array([ + [1, 2, 3], + [4, 5, 6] +]) + +def attach_with_inf(arr, idx): + target_dim = arr.ndim + idx.ndim - 1 + expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim))) + + return jnp.where(expand_idx == 1, jnp.nan, arr[idx]) + +b = attach_with_inf(arr, a) +print(b) \ No newline at end of file diff --git a/network.svg b/network.svg new file mode 100644 index 0000000..2989755 --- /dev/null +++ b/network.svg @@ -0,0 +1,415 @@ + + + + + + + + 2024-07-10T15:27:16.806503 + image/svg+xml + + + Matplotlib v3.9.0, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index 0625c88..acccc8e 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -34,9 +34,6 @@ class BaseGene(StatefulBaseClass): def forward(self, state, attrs, inputs): raise NotImplementedError - def update_by_batch(self, state, attrs, batch_inputs): - raise NotImplementedError - @property def length(self): return len(self.fixed_attrs) + len(self.custom_attrs) diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index aa716c3..502d3c6 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -31,7 +31,7 @@ class BaseGenome(StatefulBaseClass): input_transform: Callable = None, init_hidden_layers: Sequence[int] = (), ): - + # check transform functions if input_transform is not None: try: @@ -64,7 +64,7 @@ class BaseGenome(StatefulBaseClass): all_init_conns_in_idx.append(in_idx) all_init_conns_out_idx.append(out_idx) all_init_nodes.extend(in_layer) - all_init_nodes.extend(layer_indices[-1]) + all_init_nodes.extend(layer_indices[-1]) # output layer if max_nodes < len(all_init_nodes): raise ValueError( @@ -75,7 +75,7 @@ class BaseGenome(StatefulBaseClass): raise ValueError( f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}" ) - + self.num_inputs = num_inputs self.num_outputs = num_outputs self.max_nodes = max_nodes diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 38dc38b..de29eec 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -78,31 +78,34 @@ class DefaultGenome(BaseGenome): def cond_fun(carry): values, idx = carry - return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF) + return (idx < self.max_nodes) & ( + cal_seqs[idx] != I_INF + ) # not out of bounds and next node exists def body_func(carry): values, idx = carry i = cal_seqs[idx] def input_node(): - z = self.node_gene.input_transform(state, nodes_attrs[i], values[i]) - new_values = values.at[i].set(z) - return new_values + return values def otherwise(): + # calculate connections conn_indices = u_conns[:, i] - hit_attrs = attach_with_inf(conns_attrs, conn_indices) + hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))( state, hit_attrs, values ) + # calculate nodes z = self.node_gene.forward( state, nodes_attrs[i], ins, - is_output_node=jnp.isin(i, self.output_idx), + is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes ) + # set new value new_values = values.at[i].set(z) return new_values diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 6509a99..a289e8f 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -1,12 +1,13 @@ -from typing import Callable - -import jax, jax.numpy as jnp +import jax +from jax import vmap, numpy as jnp from .utils import unflatten_conns -from . import BaseGenome +from .base import BaseGenome +from .operations import DefaultMutation, DefaultCrossover, DefaultDistance +from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs from ..gene import DefaultNodeGene, DefaultConnGene -from .operations import DefaultMutation, DefaultCrossover +from tensorneat.common import attach_with_inf class RecurrentGenome(BaseGenome): """Default genome class, with the same behavior as the NEAT-Python""" @@ -17,14 +18,17 @@ class RecurrentGenome(BaseGenome): self, num_inputs: int, num_outputs: int, - max_nodes = 50, - max_conns = 100, + max_nodes=50, + max_conns=100, node_gene=DefaultNodeGene(), conn_gene=DefaultConnGene(), mutation=DefaultMutation(), crossover=DefaultCrossover(), + distance=DefaultDistance(), + output_transform=None, + input_transform=None, + init_hidden_layers=(), activate_time=10, - output_transform: Callable = None, ): super().__init__( num_inputs, @@ -35,29 +39,25 @@ class RecurrentGenome(BaseGenome): conn_gene, mutation, crossover, + distance, + output_transform, + input_transform, + init_hidden_layers, ) self.activate_time = activate_time - if output_transform is not None: - try: - _ = output_transform(jnp.zeros(num_outputs)) - except Exception as e: - raise ValueError(f"Output transform function failed: {e}") - self.output_transform = output_transform - def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) return nodes, conns, u_conns - def restore(self, state, transformed): + def forward(self, state, transformed, inputs): nodes, conns, u_conns = transformed - return nodes, conns - - def forward(self, state, inputs, transformed): - nodes, conns = transformed vals = jnp.full((self.max_nodes,), jnp.nan) - nodes_attrs = nodes[:, 1:] # remove index + + nodes_attrs = vmap(extract_node_attrs)(nodes) + conns_attrs = vmap(extract_conn_attrs)(conns) + expand_conns_attrs = attach_with_inf(conns_attrs, u_conns) def body_func(_, values): @@ -65,14 +65,14 @@ class RecurrentGenome(BaseGenome): values = values.at[self.input_idx].set(inputs) # calculate connections - node_ins = jax.vmap( - jax.vmap(self.conn_gene.forward, in_axes=(None, 1, None)), - in_axes=(None, 1, 0), - )(state, conns, values) + node_ins = vmap( + vmap(self.conn_gene.forward, in_axes=(None, 0, None)), + in_axes=(None, 0, 0), + )(state, expand_conns_attrs, values) # calculate nodes - is_output_nodes = jnp.isin(jnp.arange(self.max_nodes), self.output_idx) - values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))( + is_output_nodes = jnp.isin(nodes[:, 0], self.output_idx) + values = vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))( state, nodes_attrs, node_ins.T, is_output_nodes ) @@ -87,3 +87,6 @@ class RecurrentGenome(BaseGenome): def sympy_func(self, state, network, precision=3): raise ValueError("Sympy function is not supported for Recurrent Network!") + + def visualize(self, network): + raise ValueError("Visualize function is not supported for Recurrent Network!") diff --git a/tensorneat/common/tools.py b/tensorneat/common/tools.py index b26ebe3..ce27176 100644 --- a/tensorneat/common/tools.py +++ b/tensorneat/common/tools.py @@ -6,12 +6,11 @@ from jax import numpy as jnp, Array, jit, vmap I_INF = np.iinfo(jnp.int32).max # infinite int -# TODO: strange implementation + def attach_with_inf(arr, idx): - expand_size = arr.ndim - idx.ndim - expand_idx = jnp.expand_dims( - idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim)) - ) + target_dim = arr.ndim + idx.ndim - 1 + expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim))) + return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])