diff --git a/tensorneat/algorithm/neat/genome/__init__.py b/tensorneat/algorithm/neat/genome/__init__.py
index 74bd7ba..bf09c21 100644
--- a/tensorneat/algorithm/neat/genome/__init__.py
+++ b/tensorneat/algorithm/neat/genome/__init__.py
@@ -1,4 +1,5 @@
 from .base import BaseGenome
 from .default import DefaultGenome
 from .recurrent import RecurrentGenome
-from .advance import AdvanceInitialize
\ No newline at end of file
+from .advance import AdvanceInitialize
+from .dense import DenseInitialize
diff --git a/tensorneat/algorithm/neat/genome/advance.py b/tensorneat/algorithm/neat/genome/advance.py
new file mode 100644
index 0000000..3dbce03
--- /dev/null
+++ b/tensorneat/algorithm/neat/genome/advance.py
@@ -0,0 +1,70 @@
+import jax, jax.numpy as jnp
+from .default import DefaultGenome
+
+
+class AdvanceInitialize(DefaultGenome):
+    def __init__(self, hidden_cnt=8, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hidden_cnt = hidden_cnt
+
+    def initialize(self, state, randkey):
+
+        k1, k2 = jax.random.split(randkey, num=2)
+
+        input_idx, output_idx = self.input_idx, self.output_idx
+        input_size = len(input_idx)
+        output_size = len(output_idx)
+
+        hidden_idx = jnp.arange(
+            input_size + output_size, input_size + output_size + self.hidden_cnt
+        )
+        nodes = jnp.full(
+            (self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
+        )
+
+        nodes = nodes.at[input_idx, 0].set(input_idx)
+        nodes = nodes.at[output_idx, 0].set(output_idx)
+        nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
+
+        total_idx = input_size + output_size + self.hidden_cnt
+        rand_keys_n = jax.random.split(k1, num=total_idx)
+
+        node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
+        node_attrs = node_attr_func(state, rand_keys_n)
+        nodes = nodes.at[:total_idx, 1:].set(node_attrs)
+
+        conns = jnp.full(
+            (self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
+        )
+
+        input_to_hidden_ids, hidden_ids = jnp.meshgrid(
+            input_idx, hidden_idx, indexing="ij"
+        )
+        total_input_to_hidden_conns = input_size * self.hidden_cnt
+        conns = conns.at[:total_input_to_hidden_conns, :2].set(
+            jnp.column_stack([input_to_hidden_ids.flatten(), hidden_ids.flatten()])
+        )
+
+        hidden_to_output_ids, output_ids = jnp.meshgrid(
+            hidden_idx, output_idx, indexing="ij"
+        )
+        total_hidden_to_output_conns = self.hidden_cnt * output_size
+        conns = conns.at[
+            total_input_to_hidden_conns : total_input_to_hidden_conns
+            + total_hidden_to_output_conns,
+            :2,
+        ].set(jnp.column_stack([hidden_to_output_ids.flatten(), output_ids.flatten()]))
+
+        total_conns = total_input_to_hidden_conns + total_hidden_to_output_conns
+        rand_keys_c = jax.random.split(k2, num=total_conns)
+        conns_attr_func = jax.vmap(
+            self.conn_gene.new_random_attrs,
+            in_axes=(
+                None,
+                0,
+            ),
+        )
+        conns_attrs = conns_attr_func(state, rand_keys_c)
+        conns = conns.at[:total_conns, 2:].set(conns_attrs)
+
+        return nodes, conns
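Note: AdvanceInitialize pre-wires a dense input-to-hidden and hidden-to-output topology before evolution starts. The following standalone sketch (not part of the patch; the sizes are made-up example values and TensorNEAT itself is not imported) mirrors the index bookkeeping above to show how many nodes and connections that layout needs.

```python
# Standalone sketch of the topology AdvanceInitialize pre-wires (inputs -> hidden -> outputs).
# Only the index bookkeeping from the hunk above is reproduced here.
import jax.numpy as jnp

num_inputs, num_outputs, hidden_cnt = 3, 2, 4  # arbitrary example sizes
input_idx = jnp.arange(num_inputs)
output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
hidden_idx = jnp.arange(num_inputs + num_outputs, num_inputs + num_outputs + hidden_cnt)

# Dense input->hidden and hidden->output edge lists, in the same row-major order the
# diff produces with jnp.meshgrid(..., indexing="ij") followed by jnp.column_stack.
in2hid = jnp.stack(jnp.meshgrid(input_idx, hidden_idx, indexing="ij"), axis=-1).reshape(-1, 2)
hid2out = jnp.stack(jnp.meshgrid(hidden_idx, output_idx, indexing="ij"), axis=-1).reshape(-1, 2)
edges = jnp.concatenate([in2hid, hid2out])

# max_nodes / max_conns must leave room for this pre-built structure.
assert edges.shape[0] == num_inputs * hidden_cnt + hidden_cnt * num_outputs
```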
diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py
index 137ac95..6ccb1ec 100644
--- a/tensorneat/algorithm/neat/genome/default.py
+++ b/tensorneat/algorithm/neat/genome/default.py
@@ -2,6 +2,7 @@ import warnings
 from typing import Callable
 
 import jax, jax.numpy as jnp
+import numpy as np
 import sympy as sp
 from utils import (
     unflatten_conns,
@@ -37,6 +38,7 @@ class DefaultGenome(BaseGenome):
         mutation: BaseMutation = DefaultMutation(),
         crossover: BaseCrossover = DefaultCrossover(),
         output_transform: Callable = None,
+        input_transform: Callable = None,
     ):
         super().__init__(
             num_inputs,
@@ -49,9 +51,16 @@ class DefaultGenome(BaseGenome):
             crossover,
         )
 
+        if input_transform is not None:
+            try:
+                _ = input_transform(np.zeros(num_inputs))
+            except Exception as e:
+                raise ValueError(f"Input transform function failed: {e}")
+        self.input_transform = input_transform
+
         if output_transform is not None:
             try:
-                _ = output_transform(jnp.zeros(num_outputs))
+                _ = output_transform(np.zeros(num_outputs))
             except Exception as e:
                 raise ValueError(f"Output transform function failed: {e}")
         self.output_transform = output_transform
@@ -69,6 +78,10 @@ class DefaultGenome(BaseGenome):
         return nodes, conns
 
     def forward(self, state, transformed, inputs):
+
+        if self.input_transform is not None:
+            inputs = self.input_transform(inputs)
+
         cal_seqs, nodes, conns, u_conns = transformed
 
         ini_vals = jnp.full((self.max_nodes,), jnp.nan)
@@ -118,6 +131,10 @@ class DefaultGenome(BaseGenome):
             return self.output_transform(vals[self.output_idx])
 
     def update_by_batch(self, state, batch_input, transformed):
+
+        if self.input_transform is not None:
+            batch_input = jax.vmap(self.input_transform)(batch_input)
+
         cal_seqs, nodes, conns, u_conns = transformed
 
         batch_size = batch_input.shape[0]
@@ -193,11 +210,19 @@ class DefaultGenome(BaseGenome):
             new_transformed,
         )
 
-    def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
+    def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
         assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
 
         module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
 
+        if sympy_input_transform is None and self.input_transform is not None:
+            warnings.warn(
+                "genome.input_transform is not None but sympy_input_transform is None!"
+            )
+        if sympy_input_transform is not None:
+            if not isinstance(sympy_input_transform, list):
+                sympy_input_transform = [sympy_input_transform] * self.num_inputs
+
         if sympy_output_transform is None and self.output_transform is not None:
             warnings.warn(
                 "genome.output_transform is not None but sympy_output_transform is None!"
@@ -221,7 +246,10 @@ class DefaultGenome(BaseGenome):
 
         for i in order:
             if i in input_idx:
-                nodes_exprs[symbols[i]] = symbols[i]
+                if sympy_input_transform is not None:
+                    nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
+                else:
+                    nodes_exprs[symbols[i]] = symbols[i]
             else:
                 in_conns = [c for c in network["conns"] if c[1] == i]
                 node_inputs = []
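Note: a short standalone illustration of the transform contract this hunk introduces (the values and names below are example stand-ins, not TensorNEAT API calls): input_transform must accept an array of length num_inputs, and sympy_input_transform may be either one callable or a per-input list.

```python
# Standalone illustration of the transform contract added above (assumed example values).
import numpy as np

num_inputs = 3

def input_transform(x):
    return 2 * x - 1  # e.g. rescale observations into a symmetric range

# Mirrors the constructor's sanity check: the transform must accept a
# vector of length num_inputs without raising.
_ = input_transform(np.zeros(num_inputs))

# sympy_func accepts either one callable or a per-input list; a single callable
# is broadcast to every input, mirroring the isinstance(..., list) branch above.
sympy_input_transform = input_transform
if not isinstance(sympy_input_transform, list):
    sympy_input_transform = [sympy_input_transform] * num_inputs
assert len(sympy_input_transform) == num_inputs
```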
diff --git a/tensorneat/algorithm/neat/genome/dense.py b/tensorneat/algorithm/neat/genome/dense.py
new file mode 100644
index 0000000..1c47ef6
--- /dev/null
+++ b/tensorneat/algorithm/neat/genome/dense.py
@@ -0,0 +1,56 @@
+import jax, jax.numpy as jnp
+from .default import DefaultGenome
+
+
+class DenseInitialize(DefaultGenome):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.max_nodes >= self.num_inputs + self.num_outputs
+        assert self.max_conns >= self.num_inputs * self.num_outputs
+
+    def initialize(self, state, randkey):
+
+        k1, k2 = jax.random.split(randkey, num=2)
+
+        input_idx, output_idx = self.input_idx, self.output_idx
+        input_size = len(input_idx)
+        output_size = len(output_idx)
+
+        nodes = jnp.full(
+            (self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
+        )
+
+        nodes = nodes.at[input_idx, 0].set(input_idx)
+        nodes = nodes.at[output_idx, 0].set(output_idx)
+
+        total_idx = input_size + output_size
+        rand_keys_n = jax.random.split(k1, num=total_idx)
+
+        node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
+        node_attrs = node_attr_func(state, rand_keys_n)
+        nodes = nodes.at[:total_idx, 1:].set(node_attrs)
+
+        conns = jnp.full(
+            (self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
+        )
+
+        input_to_output_ids, output_ids = jnp.meshgrid(
+            input_idx, output_idx, indexing="ij"
+        )
+        total_conns = input_size * output_size
+        conns = conns.at[:total_conns, :2].set(
+            jnp.column_stack([input_to_output_ids.flatten(), output_ids.flatten()])
+        )
+
+        rand_keys_c = jax.random.split(k2, num=total_conns)
+        conns_attr_func = jax.vmap(
+            self.conn_gene.new_random_attrs,
+            in_axes=(
+                None,
+                0,
+            ),
+        )
+        conns_attrs = conns_attr_func(state, rand_keys_c)
+        conns = conns.at[:total_conns, 2:].set(conns_attrs)
+
+        return nodes, conns
diff --git a/tensorneat/utils/__init__.py b/tensorneat/utils/__init__.py
index b2f8df0..fd96b5a 100644
--- a/tensorneat/utils/__init__.py
+++ b/tensorneat/utils/__init__.py
@@ -20,7 +20,6 @@ name2sympy = {
     "relu": SympyRelu,
     "lelu": SympyLelu,
     "identity": SympyIdentity,
-    "clamped": SympyClamped,
     "inv": SympyInv,
     "log": SympyLog,
     "exp": SympyExp,
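Note: DenseInitialize wires every input directly to every output, so the two assertions in its constructor put a floor on max_nodes and max_conns. A minimal standalone arithmetic check (example sizes only, no TensorNEAT imports):

```python
# Minimal check of the capacity DenseInitialize's constructor asserts (example sizes).
num_inputs, num_outputs = 4, 2
max_nodes, max_conns = 16, 32  # stand-ins for the genome's buffer sizes

required_nodes = num_inputs + num_outputs  # no hidden nodes at initialization
required_conns = num_inputs * num_outputs  # one connection per (input, output) pair

assert max_nodes >= required_nodes
assert max_conns >= required_conns
```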
diff --git a/tensorneat/utils/activation/act_jnp.py b/tensorneat/utils/activation/act_jnp.py
index 2da5d7b..218ed3c 100644
--- a/tensorneat/utils/activation/act_jnp.py
+++ b/tensorneat/utils/activation/act_jnp.py
@@ -2,6 +2,9 @@ import jax
 import jax.numpy as jnp
 
 
+sigma_3 = 2.576
+
+
 class Act:
     @staticmethod
     def name2func(name):
@@ -9,35 +12,42 @@ class Act:
 
     @staticmethod
     def sigmoid(z):
-        z = jnp.clip(5 * z, -10, 10)
-        return 1 / (1 + jnp.exp(-z))
+        z = jnp.clip(5 * z / sigma_3, -5, 5)
+        z = 1 / (1 + jnp.exp(-z))
+
+        return z * sigma_3  # (0, sigma_3)
 
     @staticmethod
     def tanh(z):
-        z = jnp.clip(0.6*z, -3, 3)
-        return jnp.tanh(z)
+        z = jnp.clip(5 * z / sigma_3, -5, 5)
+        return jnp.tanh(z) * sigma_3  # (-sigma_3, sigma_3)
+
+    @staticmethod
+    def standard_tanh(z):
+        z = jnp.clip(5 * z / sigma_3, -5, 5)
+        return jnp.tanh(z)  # (-1, 1)
 
     @staticmethod
     def sin(z):
-        return jnp.sin(z)
+        z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2)
+        return jnp.sin(z) * sigma_3  # (-sigma_3, sigma_3)
 
     @staticmethod
     def relu(z):
-        return jnp.maximum(z, 0)
+        z = jnp.clip(z, -sigma_3, sigma_3)
+        return jnp.maximum(z, 0)  # (0, sigma_3)
 
     @staticmethod
     def lelu(z):
         leaky = 0.005
+        z = jnp.clip(z, -sigma_3, sigma_3)
         return jnp.where(z > 0, z, leaky * z)
 
     @staticmethod
     def identity(z):
+        z = jnp.clip(z, -sigma_3, sigma_3)
         return z
 
-    @staticmethod
-    def clamped(z):
-        return jnp.clip(z, -1, 1)
-
     @staticmethod
     def inv(z):
         z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
@@ -55,6 +65,7 @@ class Act:
 
     @staticmethod
     def abs(z):
+        z = jnp.clip(z, -1, 1)
         return jnp.abs(z)
 
 
@@ -65,7 +76,6 @@ ACT_ALL = (
     Act.relu,
     Act.lelu,
     Act.identity,
-    Act.clamped,
     Act.inv,
     Act.log,
     Act.exp,
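Note: the rewritten activations squash their outputs into roughly (0, sigma_3) or (-sigma_3, sigma_3), i.e. about the 99% interval of a standard normal, instead of (-1, 1). A standalone sanity check of those ranges follows; the functions are re-implemented from the hunk above rather than imported, since the import path depends on how the package is installed.

```python
# Standalone range check for the rescaled activations (copied from the diff, not imported).
import jax.numpy as jnp

sigma_3 = 2.576  # roughly the two-sided 99% quantile of a standard normal

def sigmoid(z):
    z = jnp.clip(5 * z / sigma_3, -5, 5)
    return 1 / (1 + jnp.exp(-z)) * sigma_3  # stays in (0, sigma_3)

def tanh(z):
    z = jnp.clip(5 * z / sigma_3, -5, 5)
    return jnp.tanh(z) * sigma_3  # stays in (-sigma_3, sigma_3)

z = jnp.linspace(-10.0, 10.0, 101)
assert 0.0 <= float(sigmoid(z).min()) and float(sigmoid(z).max()) <= sigma_3
assert float(jnp.abs(tanh(z)).max()) <= sigma_3
```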
diff --git a/tensorneat/utils/activation/act_sympy.py b/tensorneat/utils/activation/act_sympy.py
index 9a21b00..ae5f3ae 100644
--- a/tensorneat/utils/activation/act_sympy.py
+++ b/tensorneat/utils/activation/act_sympy.py
@@ -2,6 +2,9 @@ import sympy as sp
 import numpy as np
 
 
+sigma_3 = 2.576
+
+
 class SympyClip(sp.Function):
     @classmethod
     def eval(cls, val, min_val, max_val):
@@ -26,14 +29,17 @@ class SympySigmoid(sp.Function):
     @classmethod
     def eval(cls, z):
         if z.is_Number:
-            z = SympyClip(5 * z, -10, 10)
-            return 1 / (1 + sp.exp(-z))
+            z = SympyClip(5 * z / sigma_3, -5, 5)
+            z = 1 / (1 + sp.exp(-z))
+            return z * sigma_3
         return None
 
     @staticmethod
     def numerical_eval(z, backend=np):
-        z = backend.clip(5 * z, -10, 10)
-        return 1 / (1 + backend.exp(-z))
+        z = backend.clip(5 * z / sigma_3, -5, 5)
+        z = 1 / (1 + backend.exp(-z))
+
+        return z * sigma_3  # (0, sigma_3)
 
     def _sympystr(self, printer):
         return f"sigmoid({self.args[0]})"
@@ -46,36 +52,56 @@ class SympyTanh(sp.Function):
     @classmethod
     def eval(cls, z):
         if z.is_Number:
-            z = SympyClip(0.6 * z, -3, 3)
+            z = SympyClip(5 * z / sigma_3, -5, 5)
+            return sp.tanh(z) * sigma_3
+        return None
+
+    @staticmethod
+    def numerical_eval(z, backend=np):
+        z = backend.clip(5 * z / sigma_3, -5, 5)
+        return backend.tanh(z) * sigma_3  # (-sigma_3, sigma_3)
+
+
+class SympyStandardTanh(sp.Function):
+    @classmethod
+    def eval(cls, z):
+        if z.is_Number:
+            z = SympyClip(5 * z / sigma_3, -5, 5)
             return sp.tanh(z)
         return None
 
     @staticmethod
     def numerical_eval(z, backend=np):
-        z = backend.clip(0.6*z, -3, 3)
-        return backend.tanh(z)
+        z = backend.clip(5 * z / sigma_3, -5, 5)
+        return backend.tanh(z)  # (-1, 1)
 
 
 class SympySin(sp.Function):
     @classmethod
     def eval(cls, z):
-        return sp.sin(z)
+        if z.is_Number:
+            z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2)
+            return sp.sin(z) * sigma_3  # (-sigma_3, sigma_3)
+        return None
 
     @staticmethod
     def numerical_eval(z, backend=np):
-        return backend.sin(z)
+        z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2)
+        return backend.sin(z) * sigma_3  # (-sigma_3, sigma_3)
 
 
 class SympyRelu(sp.Function):
     @classmethod
     def eval(cls, z):
         if z.is_Number:
-            return sp.Piecewise((z, z > 0), (0, True))
+            z = SympyClip(z, -sigma_3, sigma_3)
+            return sp.Max(z, 0)  # (0, sigma_3)
         return None
 
     @staticmethod
     def numerical_eval(z, backend=np):
-        return backend.maximum(z, 0)
+        z = backend.clip(z, -sigma_3, sigma_3)
+        return backend.maximum(z, 0)  # (0, sigma_3)
 
     def _sympystr(self, printer):
         return f"relu({self.args[0]})"
@@ -107,21 +133,14 @@ class SympyLelu(sp.Function):
 class SympyIdentity(sp.Function):
     @classmethod
     def eval(cls, z):
-        return z
+        if z.is_Number:
+            z = SympyClip(z, -sigma_3, sigma_3)
+            return z
+        return None
 
     @staticmethod
     def numerical_eval(z, backend=np):
-        return z
-
-
-class SympyClamped(sp.Function):
-    @classmethod
-    def eval(cls, z):
-        return SympyClip(z, -1, 1)
-
-    @staticmethod
-    def numerical_eval(z, backend=np):
-        return backend.clip(z, -1, 1)
+        return backend.clip(z, -sigma_3, sigma_3)
 
 
 class SympyInv(sp.Function):
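Note: a standalone spot check that the symbolic and numerical forms of the rescaled sigmoid agree; plain sympy Min/Max stand in for SympyClip here, and the probe values are arbitrary.

```python
# Cross-check the symbolic clipped sigmoid against its numpy counterpart (standalone sketch).
import numpy as np
import sympy as sp

sigma_3 = 2.576
x = sp.Symbol("x")
clipped = sp.Max(sp.Min(5 * x / sigma_3, 5), -5)   # Min/Max play the role of SympyClip
sigmoid_expr = 1 / (1 + sp.exp(-clipped)) * sigma_3

for z in (-10.0, 0.0, 10.0):
    sym_val = float(sigmoid_expr.subs(x, z))
    num_val = 1 / (1 + np.exp(-np.clip(5 * z / sigma_3, -5, 5))) * sigma_3
    assert abs(sym_val - num_val) < 1e-9
```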