diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py index ac48bad..902c8c0 100644 --- a/algorithm/hyperneat/hyperneat.py +++ b/algorithm/hyperneat/hyperneat.py @@ -6,7 +6,7 @@ import numpy as np from config import Config, HyperNeatConfig from core import Algorithm, Substrate, State, Genome, Gene -from utils import Activation, Aggregation +from utils import Act, Agg from .substrate import analysis_substrate from algorithm import NEAT @@ -90,10 +90,7 @@ class HyperNEATGene: @staticmethod def forward(config: HyperNeatConfig, state: State, inputs, transformed): - act = Activation.name2func[config.activation] - agg = Aggregation.name2func[config.aggregation] - - batch_act, batch_agg = jax.vmap(act), jax.vmap(agg) + batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation) nodes, weights = transformed diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index 2d9caf0..fcff108 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -6,7 +6,7 @@ from jax import Array, numpy as jnp from config import GeneConfig from core import Gene, Genome, State -from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg +from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg @dataclass(frozen=True) @@ -23,12 +23,12 @@ class NormalGeneConfig(GeneConfig): response_mutate_rate: float = 0.7 response_replace_rate: float = 0.1 - activation_default: str = 'sigmoid' - activation_options: Tuple = ('sigmoid',) + activation_default: callable = Act.sigmoid + activation_options: Tuple = (Act.sigmoid, ) activation_replace_rate: float = 0.1 - aggregation_default: str = 'sum' - aggregation_options: Tuple = ('sum',) + aggregation_default: callable = Agg.sum + aggregation_options: Tuple = (Agg.sum, ) aggregation_replace_rate: float = 0.1 weight_init_mean: float = 0.0 @@ -49,18 +49,8 @@ class NormalGeneConfig(GeneConfig): assert self.response_replace_rate >= 0.0 assert self.activation_default == self.activation_options[0] - - for name in self.activation_options: - assert name in Activation.name2func, f"Activation function: {name} not found" - assert self.aggregation_default == self.aggregation_options[0] - assert self.aggregation_default in Aggregation.name2func, \ - f"Aggregation function: {self.aggregation_default} not found" - - for name in self.aggregation_options: - assert name in Aggregation.name2func, f"Aggregation function: {name} not found" - class NormalGene(Gene): node_attrs = ['bias', 'response', 'aggregation', 'activation'] @@ -68,8 +58,6 @@ class NormalGene(Gene): def __init__(self, config: NormalGeneConfig = NormalGeneConfig()): self.config = config - self.act_funcs = [Activation.name2func[name] for name in config.activation_options] - self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] def setup(self, state: State = State()): return state.update( @@ -170,9 +158,9 @@ class NormalGene(Gene): def hit(): ins = values * weights[:, i] - z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins) + z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins) z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias - z = act(nodes[i, 3], z, self.act_funcs) # z = act(z) + z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z) new_values = values.at[i].set(z) return new_values diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py index d82d7bf..632eea7 100644 --- a/algorithm/neat/gene/recurrent.py +++ b/algorithm/neat/gene/recurrent.py @@ -48,9 +48,9 @@ class RecurrentGene(NormalGene): def body_func(i, values): values = values.at[input_idx].set(inputs) nodes_ins = values * weights.T - values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins) + values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins) values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias - values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z) + values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z) return values vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals) diff --git a/config/config.py b/config/config.py index 87c55c2..39f9d8f 100644 --- a/config/config.py +++ b/config/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass - +from utils import Act, Agg @dataclass(frozen=True) class BasicConfig: @@ -68,8 +68,8 @@ class NeatConfig: class HyperNeatConfig: below_threshold: float = 0.2 max_weight: float = 3 - activation: str = "sigmoid" - aggregation: str = "sum" + activation: callable = Act.sigmoid + aggregation: callable = Agg.sum activate_times: int = 5 inputs: int = 2 outputs: int = 1 diff --git a/examples/func_fit/xor_hyperneat.py b/examples/func_fit/xor_hyperneat.py index 23fc389..cfd23f1 100644 --- a/examples/func_fit/xor_hyperneat.py +++ b/examples/func_fit/xor_hyperneat.py @@ -3,6 +3,7 @@ from pipeline import Pipeline from algorithm.neat import NormalGene, NormalGeneConfig from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig from problem.func_fit import XOR3d, FuncFitConfig +from utils import Act if __name__ == '__main__': @@ -27,8 +28,8 @@ if __name__ == '__main__': input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)), ), gene=NormalGeneConfig( - activation_default='tanh', - activation_options=('tanh', ), + activation_default=Act.tanh, + activation_options=(Act.tanh, ), ), problem=FuncFitConfig() ) diff --git a/examples/func_fit/xor_recurrent.py b/examples/func_fit/xor_recurrent.py index 2eed951..d100fd8 100644 --- a/examples/func_fit/xor_recurrent.py +++ b/examples/func_fit/xor_recurrent.py @@ -36,5 +36,6 @@ if __name__ == '__main__': algorithm = NEAT(config, RecurrentGene) pipeline = Pipeline(config, algorithm, XOR3d) state = pipeline.setup() + pipeline.pre_compile(state) state, best = pipeline.auto_run(state) pipeline.show(state, best) diff --git a/examples/gymnax/cartpole.py b/examples/gymnax/cartpole.py index 3ce61c5..bc18d86 100644 --- a/examples/gymnax/cartpole.py +++ b/examples/gymnax/cartpole.py @@ -19,8 +19,8 @@ def example_conf1(): outputs=1, ), gene=NormalGeneConfig( - activation_default='sigmoid', - activation_options=('sigmoid',), + activation_default=Act.sigmoid, + activation_options=(Act.sigmoid,), ), problem=GymNaxConfig( env_name='CartPole-v1', @@ -41,8 +41,8 @@ def example_conf2(): outputs=1, ), gene=NormalGeneConfig( - activation_default='tanh', - activation_options=('tanh',), + activation_default=Act.tanh, + activation_options=(Act.tanh,), ), problem=GymNaxConfig( env_name='CartPole-v1', @@ -63,8 +63,8 @@ def example_conf3(): outputs=2, ), gene=NormalGeneConfig( - activation_default='tanh', - activation_options=('tanh',), + activation_default=Act.tanh, + activation_options=(Act.tanh,), ), problem=GymNaxConfig( env_name='CartPole-v1', @@ -80,5 +80,5 @@ if __name__ == '__main__': algorithm = NEAT(conf, NormalGene) pipeline = Pipeline(conf, algorithm, GymNaxEnv) state = pipeline.setup() + pipeline.pre_compile(state) state, best = pipeline.auto_run(state) - pipeline.show(state, best) diff --git a/utils/__init__.py b/utils/__init__.py index 9820a71..f8237b0 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,35 +1,4 @@ -from .activation import Activation, act -from .aggregation import Aggregation, agg +from .activation import Act, act +from .aggregation import Agg, agg from .tools import * -from .graph import * - -Activation.name2func = { - 'sigmoid': Activation.sigmoid_act, - 'tanh': Activation.tanh_act, - 'sin': Activation.sin_act, - 'gauss': Activation.gauss_act, - 'relu': Activation.relu_act, - 'elu': Activation.elu_act, - 'lelu': Activation.lelu_act, - 'selu': Activation.selu_act, - 'softplus': Activation.softplus_act, - 'identity': Activation.identity_act, - 'clamped': Activation.clamped_act, - 'inv': Activation.inv_act, - 'log': Activation.log_act, - 'exp': Activation.exp_act, - 'abs': Activation.abs_act, - 'hat': Activation.hat_act, - 'square': Activation.square_act, - 'cube': Activation.cube_act, -} - -Aggregation.name2func = { - 'sum': Aggregation.sum_agg, - 'product': Aggregation.product_agg, - 'max': Aggregation.max_agg, - 'min': Aggregation.min_agg, - 'maxabs': Aggregation.maxabs_agg, - 'median': Aggregation.median_agg, - 'mean': Aggregation.mean_agg, -} +from .graph import * \ No newline at end of file diff --git a/utils/activation.py b/utils/activation.py index f580d57..9795c31 100644 --- a/utils/activation.py +++ b/utils/activation.py @@ -2,90 +2,89 @@ import jax import jax.numpy as jnp -class Activation: - name2func = {} +class Act: @staticmethod - def sigmoid_act(z): + def sigmoid(z): z = jnp.clip(z * 5, -60, 60) return 1 / (1 + jnp.exp(-z)) @staticmethod - def tanh_act(z): + def tanh(z): z = jnp.clip(z * 2.5, -60, 60) return jnp.tanh(z) @staticmethod - def sin_act(z): + def sin(z): z = jnp.clip(z * 5, -60, 60) return jnp.sin(z) @staticmethod - def gauss_act(z): + def gauss(z): z = jnp.clip(z * 5, -3.4, 3.4) return jnp.exp(-z ** 2) @staticmethod - def relu_act(z): + def relu(z): return jnp.maximum(z, 0) @staticmethod - def elu_act(z): + def elu(z): return jnp.where(z > 0, z, jnp.exp(z) - 1) @staticmethod - def lelu_act(z): + def lelu(z): leaky = 0.005 return jnp.where(z > 0, z, leaky * z) @staticmethod - def selu_act(z): + def selu(z): lam = 1.0507009873554804934193349852946 alpha = 1.6732632423543772848170429916717 return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1)) @staticmethod - def softplus_act(z): + def softplus(z): z = jnp.clip(z * 5, -60, 60) return 0.2 * jnp.log(1 + jnp.exp(z)) @staticmethod - def identity_act(z): + def identity(z): return z @staticmethod - def clamped_act(z): + def clamped(z): return jnp.clip(z, -1, 1) @staticmethod - def inv_act(z): + def inv(z): z = jnp.maximum(z, 1e-7) return 1 / z @staticmethod - def log_act(z): + def log(z): z = jnp.maximum(z, 1e-7) return jnp.log(z) @staticmethod - def exp_act(z): + def exp(z): z = jnp.clip(z, -60, 60) return jnp.exp(z) @staticmethod - def abs_act(z): + def abs(z): return jnp.abs(z) @staticmethod - def hat_act(z): + def hat(z): return jnp.maximum(0, 1 - jnp.abs(z)) @staticmethod - def square_act(z): + def square(z): return z ** 2 @staticmethod - def cube_act(z): + def cube(z): return z ** 3 diff --git a/utils/aggregation.py b/utils/aggregation.py index 86c686a..4b94fe4 100644 --- a/utils/aggregation.py +++ b/utils/aggregation.py @@ -2,38 +2,37 @@ import jax import jax.numpy as jnp -class Aggregation: - name2func = {} +class Agg: @staticmethod - def sum_agg(z): + def sum(z): z = jnp.where(jnp.isnan(z), 0, z) return jnp.sum(z, axis=0) @staticmethod - def product_agg(z): + def product(z): z = jnp.where(jnp.isnan(z), 1, z) return jnp.prod(z, axis=0) @staticmethod - def max_agg(z): + def max(z): z = jnp.where(jnp.isnan(z), -jnp.inf, z) return jnp.max(z, axis=0) @staticmethod - def min_agg(z): + def min(z): z = jnp.where(jnp.isnan(z), jnp.inf, z) return jnp.min(z, axis=0) @staticmethod - def maxabs_agg(z): + def maxabs(z): z = jnp.where(jnp.isnan(z), 0, z) abs_z = jnp.abs(z) max_abs_index = jnp.argmax(abs_z) return z[max_abs_index] @staticmethod - def median_agg(z): + def median(z): n = jnp.sum(~jnp.isnan(z), axis=0) z = jnp.sort(z) # sort @@ -44,7 +43,7 @@ class Aggregation: return median @staticmethod - def mean_agg(z): + def mean(z): aux = jnp.where(jnp.isnan(z), 0, z) valid_values_sum = jnp.sum(aux, axis=0) valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)