Change the activation/aggregation (act, agg) config fields from str names to callables
This commit is contained in:
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from config import Config, HyperNeatConfig
|
||||
from core import Algorithm, Substrate, State, Genome, Gene
|
||||
from utils import Activation, Aggregation
|
||||
from utils import Act, Agg
|
||||
from .substrate import analysis_substrate
|
||||
from algorithm import NEAT
|
||||
|
||||
@@ -90,10 +90,7 @@ class HyperNEATGene:
|
||||
|
||||
@staticmethod
|
||||
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
|
||||
act = Activation.name2func[config.activation]
|
||||
agg = Aggregation.name2func[config.aggregation]
|
||||
|
||||
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
|
||||
batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation)
|
||||
|
||||
nodes, weights = transformed
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
|
||||
|
||||
from config import GeneConfig
|
||||
from core import Gene, Genome, State
|
||||
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
|
||||
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -23,12 +23,12 @@ class NormalGeneConfig(GeneConfig):
|
||||
response_mutate_rate: float = 0.7
|
||||
response_replace_rate: float = 0.1
|
||||
|
||||
activation_default: str = 'sigmoid'
|
||||
activation_options: Tuple = ('sigmoid',)
|
||||
activation_default: callable = Act.sigmoid
|
||||
activation_options: Tuple = (Act.sigmoid, )
|
||||
activation_replace_rate: float = 0.1
|
||||
|
||||
aggregation_default: str = 'sum'
|
||||
aggregation_options: Tuple = ('sum',)
|
||||
aggregation_default: callable = Agg.sum
|
||||
aggregation_options: Tuple = (Agg.sum, )
|
||||
aggregation_replace_rate: float = 0.1
|
||||
|
||||
weight_init_mean: float = 0.0
|
||||
@@ -49,18 +49,8 @@ class NormalGeneConfig(GeneConfig):
|
||||
assert self.response_replace_rate >= 0.0
|
||||
|
||||
assert self.activation_default == self.activation_options[0]
|
||||
|
||||
for name in self.activation_options:
|
||||
assert name in Activation.name2func, f"Activation function: {name} not found"
|
||||
|
||||
assert self.aggregation_default == self.aggregation_options[0]
|
||||
|
||||
assert self.aggregation_default in Aggregation.name2func, \
|
||||
f"Aggregation function: {self.aggregation_default} not found"
|
||||
|
||||
for name in self.aggregation_options:
|
||||
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
|
||||
|
||||
|
||||
class NormalGene(Gene):
|
||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
@@ -68,8 +58,6 @@ class NormalGene(Gene):
|
||||
|
||||
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
|
||||
self.config = config
|
||||
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||
|
||||
def setup(self, state: State = State()):
|
||||
return state.update(
|
||||
@@ -170,9 +158,9 @@ class NormalGene(Gene):
|
||||
|
||||
def hit():
|
||||
ins = values * weights[:, i]
|
||||
z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
|
||||
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
|
||||
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
||||
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
@@ -48,9 +48,9 @@ class RecurrentGene(NormalGene):
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
|
||||
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
|
||||
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from utils import Act, Agg
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BasicConfig:
|
||||
@@ -68,8 +68,8 @@ class NeatConfig:
|
||||
class HyperNeatConfig:
|
||||
below_threshold: float = 0.2
|
||||
max_weight: float = 3
|
||||
activation: str = "sigmoid"
|
||||
aggregation: str = "sum"
|
||||
activation: callable = Act.sigmoid
|
||||
aggregation: callable = Agg.sum
|
||||
activate_times: int = 5
|
||||
inputs: int = 2
|
||||
outputs: int = 1
|
||||
|
||||
@@ -3,6 +3,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import NormalGene, NormalGeneConfig
|
||||
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
||||
from problem.func_fit import XOR3d, FuncFitConfig
|
||||
from utils import Act
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -27,8 +28,8 @@ if __name__ == '__main__':
|
||||
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='tanh',
|
||||
activation_options=('tanh', ),
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh, ),
|
||||
),
|
||||
problem=FuncFitConfig()
|
||||
)
|
||||
|
||||
@@ -36,5 +36,6 @@ if __name__ == '__main__':
|
||||
algorithm = NEAT(config, RecurrentGene)
|
||||
pipeline = Pipeline(config, algorithm, XOR3d)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
pipeline.show(state, best)
|
||||
|
||||
@@ -19,8 +19,8 @@ def example_conf1():
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='sigmoid',
|
||||
activation_options=('sigmoid',),
|
||||
activation_default=Act.sigmoid,
|
||||
activation_options=(Act.sigmoid,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
@@ -41,8 +41,8 @@ def example_conf2():
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='tanh',
|
||||
activation_options=('tanh',),
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
@@ -63,8 +63,8 @@ def example_conf3():
|
||||
outputs=2,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='tanh',
|
||||
activation_options=('tanh',),
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
@@ -80,5 +80,5 @@ if __name__ == '__main__':
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
pipeline.show(state, best)
|
||||
|
||||
@@ -1,35 +1,4 @@
|
||||
from .activation import Activation, act
|
||||
from .aggregation import Aggregation, agg
|
||||
from .activation import Act, act
|
||||
from .aggregation import Agg, agg
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
|
||||
Activation.name2func = {
|
||||
'sigmoid': Activation.sigmoid_act,
|
||||
'tanh': Activation.tanh_act,
|
||||
'sin': Activation.sin_act,
|
||||
'gauss': Activation.gauss_act,
|
||||
'relu': Activation.relu_act,
|
||||
'elu': Activation.elu_act,
|
||||
'lelu': Activation.lelu_act,
|
||||
'selu': Activation.selu_act,
|
||||
'softplus': Activation.softplus_act,
|
||||
'identity': Activation.identity_act,
|
||||
'clamped': Activation.clamped_act,
|
||||
'inv': Activation.inv_act,
|
||||
'log': Activation.log_act,
|
||||
'exp': Activation.exp_act,
|
||||
'abs': Activation.abs_act,
|
||||
'hat': Activation.hat_act,
|
||||
'square': Activation.square_act,
|
||||
'cube': Activation.cube_act,
|
||||
}
|
||||
|
||||
Aggregation.name2func = {
|
||||
'sum': Aggregation.sum_agg,
|
||||
'product': Aggregation.product_agg,
|
||||
'max': Aggregation.max_agg,
|
||||
'min': Aggregation.min_agg,
|
||||
'maxabs': Aggregation.maxabs_agg,
|
||||
'median': Aggregation.median_agg,
|
||||
'mean': Aggregation.mean_agg,
|
||||
}
|
||||
|
||||
@@ -2,90 +2,89 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Activation:
|
||||
name2func = {}
|
||||
class Act:
|
||||
|
||||
@staticmethod
|
||||
def sigmoid_act(z):
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 1 / (1 + jnp.exp(-z))
|
||||
|
||||
@staticmethod
|
||||
def tanh_act(z):
|
||||
def tanh(z):
|
||||
z = jnp.clip(z * 2.5, -60, 60)
|
||||
return jnp.tanh(z)
|
||||
|
||||
@staticmethod
|
||||
def sin_act(z):
|
||||
def sin(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return jnp.sin(z)
|
||||
|
||||
@staticmethod
|
||||
def gauss_act(z):
|
||||
def gauss(z):
|
||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||
return jnp.exp(-z ** 2)
|
||||
|
||||
@staticmethod
|
||||
def relu_act(z):
|
||||
def relu(z):
|
||||
return jnp.maximum(z, 0)
|
||||
|
||||
@staticmethod
|
||||
def elu_act(z):
|
||||
def elu(z):
|
||||
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
||||
|
||||
@staticmethod
|
||||
def lelu_act(z):
|
||||
def lelu(z):
|
||||
leaky = 0.005
|
||||
return jnp.where(z > 0, z, leaky * z)
|
||||
|
||||
@staticmethod
|
||||
def selu_act(z):
|
||||
def selu(z):
|
||||
lam = 1.0507009873554804934193349852946
|
||||
alpha = 1.6732632423543772848170429916717
|
||||
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
||||
|
||||
@staticmethod
|
||||
def softplus_act(z):
|
||||
def softplus(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 0.2 * jnp.log(1 + jnp.exp(z))
|
||||
|
||||
@staticmethod
|
||||
def identity_act(z):
|
||||
def identity(z):
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def clamped_act(z):
|
||||
def clamped(z):
|
||||
return jnp.clip(z, -1, 1)
|
||||
|
||||
@staticmethod
|
||||
def inv_act(z):
|
||||
def inv(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return 1 / z
|
||||
|
||||
@staticmethod
|
||||
def log_act(z):
|
||||
def log(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return jnp.log(z)
|
||||
|
||||
@staticmethod
|
||||
def exp_act(z):
|
||||
def exp(z):
|
||||
z = jnp.clip(z, -60, 60)
|
||||
return jnp.exp(z)
|
||||
|
||||
@staticmethod
|
||||
def abs_act(z):
|
||||
def abs(z):
|
||||
return jnp.abs(z)
|
||||
|
||||
@staticmethod
|
||||
def hat_act(z):
|
||||
def hat(z):
|
||||
return jnp.maximum(0, 1 - jnp.abs(z))
|
||||
|
||||
@staticmethod
|
||||
def square_act(z):
|
||||
def square(z):
|
||||
return z ** 2
|
||||
|
||||
@staticmethod
|
||||
def cube_act(z):
|
||||
def cube(z):
|
||||
return z ** 3
|
||||
|
||||
|
||||
|
||||
@@ -2,38 +2,37 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Aggregation:
|
||||
name2func = {}
|
||||
class Agg:
|
||||
|
||||
@staticmethod
|
||||
def sum_agg(z):
|
||||
def sum(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def product_agg(z):
|
||||
def product(z):
|
||||
z = jnp.where(jnp.isnan(z), 1, z)
|
||||
return jnp.prod(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def max_agg(z):
|
||||
def max(z):
|
||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||
return jnp.max(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def min_agg(z):
|
||||
def min(z):
|
||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||
return jnp.min(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def maxabs_agg(z):
|
||||
def maxabs(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
abs_z = jnp.abs(z)
|
||||
max_abs_index = jnp.argmax(abs_z)
|
||||
return z[max_abs_index]
|
||||
|
||||
@staticmethod
|
||||
def median_agg(z):
|
||||
def median(z):
|
||||
n = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
z = jnp.sort(z) # sort
|
||||
@@ -44,7 +43,7 @@ class Aggregation:
|
||||
return median
|
||||
|
||||
@staticmethod
|
||||
def mean_agg(z):
|
||||
def mean(z):
|
||||
aux = jnp.where(jnp.isnan(z), 0, z)
|
||||
valid_values_sum = jnp.sum(aux, axis=0)
|
||||
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
Reference in New Issue
Block a user