change str in config (act, agg) from str to callable

This commit is contained in:
wls2002
2023-08-05 03:03:02 +08:00
parent 0e44b13291
commit af54db3b12
10 changed files with 55 additions and 101 deletions

View File

@@ -6,7 +6,7 @@ import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome, Gene
from utils import Activation, Aggregation
from utils import Act, Agg
from .substrate import analysis_substrate
from algorithm import NEAT
@@ -90,10 +90,7 @@ class HyperNEATGene:
@staticmethod
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
act = Activation.name2func[config.activation]
agg = Aggregation.name2func[config.aggregation]
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation)
nodes, weights = transformed

View File

@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
from config import GeneConfig
from core import Gene, Genome, State
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True)
@@ -23,12 +23,12 @@ class NormalGeneConfig(GeneConfig):
response_mutate_rate: float = 0.7
response_replace_rate: float = 0.1
activation_default: str = 'sigmoid'
activation_options: Tuple = ('sigmoid',)
activation_default: callable = Act.sigmoid
activation_options: Tuple = (Act.sigmoid, )
activation_replace_rate: float = 0.1
aggregation_default: str = 'sum'
aggregation_options: Tuple = ('sum',)
aggregation_default: callable = Agg.sum
aggregation_options: Tuple = (Agg.sum, )
aggregation_replace_rate: float = 0.1
weight_init_mean: float = 0.0
@@ -49,18 +49,8 @@ class NormalGeneConfig(GeneConfig):
assert self.response_replace_rate >= 0.0
assert self.activation_default == self.activation_options[0]
for name in self.activation_options:
assert name in Activation.name2func, f"Activation function: {name} not found"
assert self.aggregation_default == self.aggregation_options[0]
assert self.aggregation_default in Aggregation.name2func, \
f"Aggregation function: {self.aggregation_default} not found"
for name in self.aggregation_options:
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
@@ -68,8 +58,6 @@ class NormalGene(Gene):
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
self.config = config
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def setup(self, state: State = State()):
return state.update(
@@ -170,9 +158,9 @@ class NormalGene(Gene):
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z)
return new_values

View File

@@ -48,9 +48,9 @@ class RecurrentGene(NormalGene):
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from utils import Act, Agg
@dataclass(frozen=True)
class BasicConfig:
@@ -68,8 +68,8 @@ class NeatConfig:
class HyperNeatConfig:
below_threshold: float = 0.2
max_weight: float = 3
activation: str = "sigmoid"
aggregation: str = "sum"
activation: callable = Act.sigmoid
aggregation: callable = Agg.sum
activate_times: int = 5
inputs: int = 2
outputs: int = 1

View File

@@ -3,6 +3,7 @@ from pipeline import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig
from utils import Act
if __name__ == '__main__':
@@ -27,8 +28,8 @@ if __name__ == '__main__':
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
),
gene=NormalGeneConfig(
activation_default='tanh',
activation_options=('tanh', ),
activation_default=Act.tanh,
activation_options=(Act.tanh, ),
),
problem=FuncFitConfig()
)

View File

@@ -36,5 +36,6 @@ if __name__ == '__main__':
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm, XOR3d)
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
pipeline.show(state, best)

View File

@@ -19,8 +19,8 @@ def example_conf1():
outputs=1,
),
gene=NormalGeneConfig(
activation_default='sigmoid',
activation_options=('sigmoid',),
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
@@ -41,8 +41,8 @@ def example_conf2():
outputs=1,
),
gene=NormalGeneConfig(
activation_default='tanh',
activation_options=('tanh',),
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
@@ -63,8 +63,8 @@ def example_conf3():
outputs=2,
),
gene=NormalGeneConfig(
activation_default='tanh',
activation_options=('tanh',),
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
@@ -80,5 +80,5 @@ if __name__ == '__main__':
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
pipeline.show(state, best)

View File

@@ -1,35 +1,4 @@
from .activation import Activation, act
from .aggregation import Aggregation, agg
from .activation import Act, act
from .aggregation import Agg, agg
from .tools import *
from .graph import *
Activation.name2func = {
'sigmoid': Activation.sigmoid_act,
'tanh': Activation.tanh_act,
'sin': Activation.sin_act,
'gauss': Activation.gauss_act,
'relu': Activation.relu_act,
'elu': Activation.elu_act,
'lelu': Activation.lelu_act,
'selu': Activation.selu_act,
'softplus': Activation.softplus_act,
'identity': Activation.identity_act,
'clamped': Activation.clamped_act,
'inv': Activation.inv_act,
'log': Activation.log_act,
'exp': Activation.exp_act,
'abs': Activation.abs_act,
'hat': Activation.hat_act,
'square': Activation.square_act,
'cube': Activation.cube_act,
}
Aggregation.name2func = {
'sum': Aggregation.sum_agg,
'product': Aggregation.product_agg,
'max': Aggregation.max_agg,
'min': Aggregation.min_agg,
'maxabs': Aggregation.maxabs_agg,
'median': Aggregation.median_agg,
'mean': Aggregation.mean_agg,
}

View File

@@ -2,90 +2,89 @@ import jax
import jax.numpy as jnp
class Activation:
name2func = {}
class Act:
@staticmethod
def sigmoid_act(z):
def sigmoid(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@staticmethod
def tanh_act(z):
def tanh(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@staticmethod
def sin_act(z):
def sin(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@staticmethod
def gauss_act(z):
def gauss(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@staticmethod
def relu_act(z):
def relu(z):
return jnp.maximum(z, 0)
@staticmethod
def elu_act(z):
def elu(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@staticmethod
def lelu_act(z):
def lelu(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def selu_act(z):
def selu(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@staticmethod
def softplus_act(z):
def softplus(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@staticmethod
def identity_act(z):
def identity(z):
return z
@staticmethod
def clamped_act(z):
def clamped(z):
return jnp.clip(z, -1, 1)
@staticmethod
def inv_act(z):
def inv(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@staticmethod
def log_act(z):
def log(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@staticmethod
def exp_act(z):
def exp(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@staticmethod
def abs_act(z):
def abs(z):
return jnp.abs(z)
@staticmethod
def hat_act(z):
def hat(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@staticmethod
def square_act(z):
def square(z):
return z ** 2
@staticmethod
def cube_act(z):
def cube(z):
return z ** 3

View File

@@ -2,38 +2,37 @@ import jax
import jax.numpy as jnp
class Aggregation:
name2func = {}
class Agg:
@staticmethod
def sum_agg(z):
def sum(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@staticmethod
def product_agg(z):
def product(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@staticmethod
def max_agg(z):
def max(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@staticmethod
def min_agg(z):
def min(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@staticmethod
def maxabs_agg(z):
def maxabs(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@staticmethod
def median_agg(z):
def median(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
@@ -44,7 +43,7 @@ class Aggregation:
return median
@staticmethod
def mean_agg(z):
def mean(z):
aux = jnp.where(jnp.isnan(z), 0, z)
valid_values_sum = jnp.sum(aux, axis=0)
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)