change str in config (act, agg) from str to callable
This commit is contained in:
@@ -6,7 +6,7 @@ import numpy as np
|
|||||||
|
|
||||||
from config import Config, HyperNeatConfig
|
from config import Config, HyperNeatConfig
|
||||||
from core import Algorithm, Substrate, State, Genome, Gene
|
from core import Algorithm, Substrate, State, Genome, Gene
|
||||||
from utils import Activation, Aggregation
|
from utils import Act, Agg
|
||||||
from .substrate import analysis_substrate
|
from .substrate import analysis_substrate
|
||||||
from algorithm import NEAT
|
from algorithm import NEAT
|
||||||
|
|
||||||
@@ -90,10 +90,7 @@ class HyperNEATGene:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
|
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
|
||||||
act = Activation.name2func[config.activation]
|
batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation)
|
||||||
agg = Aggregation.name2func[config.aggregation]
|
|
||||||
|
|
||||||
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
|
|
||||||
|
|
||||||
nodes, weights = transformed
|
nodes, weights = transformed
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
|
|||||||
|
|
||||||
from config import GeneConfig
|
from config import GeneConfig
|
||||||
from core import Gene, Genome, State
|
from core import Gene, Genome, State
|
||||||
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
|
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -23,12 +23,12 @@ class NormalGeneConfig(GeneConfig):
|
|||||||
response_mutate_rate: float = 0.7
|
response_mutate_rate: float = 0.7
|
||||||
response_replace_rate: float = 0.1
|
response_replace_rate: float = 0.1
|
||||||
|
|
||||||
activation_default: str = 'sigmoid'
|
activation_default: callable = Act.sigmoid
|
||||||
activation_options: Tuple = ('sigmoid',)
|
activation_options: Tuple = (Act.sigmoid, )
|
||||||
activation_replace_rate: float = 0.1
|
activation_replace_rate: float = 0.1
|
||||||
|
|
||||||
aggregation_default: str = 'sum'
|
aggregation_default: callable = Agg.sum
|
||||||
aggregation_options: Tuple = ('sum',)
|
aggregation_options: Tuple = (Agg.sum, )
|
||||||
aggregation_replace_rate: float = 0.1
|
aggregation_replace_rate: float = 0.1
|
||||||
|
|
||||||
weight_init_mean: float = 0.0
|
weight_init_mean: float = 0.0
|
||||||
@@ -49,18 +49,8 @@ class NormalGeneConfig(GeneConfig):
|
|||||||
assert self.response_replace_rate >= 0.0
|
assert self.response_replace_rate >= 0.0
|
||||||
|
|
||||||
assert self.activation_default == self.activation_options[0]
|
assert self.activation_default == self.activation_options[0]
|
||||||
|
|
||||||
for name in self.activation_options:
|
|
||||||
assert name in Activation.name2func, f"Activation function: {name} not found"
|
|
||||||
|
|
||||||
assert self.aggregation_default == self.aggregation_options[0]
|
assert self.aggregation_default == self.aggregation_options[0]
|
||||||
|
|
||||||
assert self.aggregation_default in Aggregation.name2func, \
|
|
||||||
f"Aggregation function: {self.aggregation_default} not found"
|
|
||||||
|
|
||||||
for name in self.aggregation_options:
|
|
||||||
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
|
|
||||||
|
|
||||||
|
|
||||||
class NormalGene(Gene):
|
class NormalGene(Gene):
|
||||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||||
@@ -68,8 +58,6 @@ class NormalGene(Gene):
|
|||||||
|
|
||||||
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
|
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
|
||||||
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
|
||||||
|
|
||||||
def setup(self, state: State = State()):
|
def setup(self, state: State = State()):
|
||||||
return state.update(
|
return state.update(
|
||||||
@@ -170,9 +158,9 @@ class NormalGene(Gene):
|
|||||||
|
|
||||||
def hit():
|
def hit():
|
||||||
ins = values * weights[:, i]
|
ins = values * weights[:, i]
|
||||||
z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
|
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
||||||
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||||
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
|
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
||||||
|
|
||||||
new_values = values.at[i].set(z)
|
new_values = values.at[i].set(z)
|
||||||
return new_values
|
return new_values
|
||||||
|
|||||||
@@ -48,9 +48,9 @@ class RecurrentGene(NormalGene):
|
|||||||
def body_func(i, values):
|
def body_func(i, values):
|
||||||
values = values.at[input_idx].set(inputs)
|
values = values.at[input_idx].set(inputs)
|
||||||
nodes_ins = values * weights.T
|
nodes_ins = values * weights.T
|
||||||
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
|
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
|
||||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||||
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
|
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
|
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from utils import Act, Agg
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BasicConfig:
|
class BasicConfig:
|
||||||
@@ -68,8 +68,8 @@ class NeatConfig:
|
|||||||
class HyperNeatConfig:
|
class HyperNeatConfig:
|
||||||
below_threshold: float = 0.2
|
below_threshold: float = 0.2
|
||||||
max_weight: float = 3
|
max_weight: float = 3
|
||||||
activation: str = "sigmoid"
|
activation: callable = Act.sigmoid
|
||||||
aggregation: str = "sum"
|
aggregation: callable = Agg.sum
|
||||||
activate_times: int = 5
|
activate_times: int = 5
|
||||||
inputs: int = 2
|
inputs: int = 2
|
||||||
outputs: int = 1
|
outputs: int = 1
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from pipeline import Pipeline
|
|||||||
from algorithm.neat import NormalGene, NormalGeneConfig
|
from algorithm.neat import NormalGene, NormalGeneConfig
|
||||||
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
||||||
from problem.func_fit import XOR3d, FuncFitConfig
|
from problem.func_fit import XOR3d, FuncFitConfig
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -27,8 +28,8 @@ if __name__ == '__main__':
|
|||||||
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
|
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
|
||||||
),
|
),
|
||||||
gene=NormalGeneConfig(
|
gene=NormalGeneConfig(
|
||||||
activation_default='tanh',
|
activation_default=Act.tanh,
|
||||||
activation_options=('tanh', ),
|
activation_options=(Act.tanh, ),
|
||||||
),
|
),
|
||||||
problem=FuncFitConfig()
|
problem=FuncFitConfig()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,5 +36,6 @@ if __name__ == '__main__':
|
|||||||
algorithm = NEAT(config, RecurrentGene)
|
algorithm = NEAT(config, RecurrentGene)
|
||||||
pipeline = Pipeline(config, algorithm, XOR3d)
|
pipeline = Pipeline(config, algorithm, XOR3d)
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
|
pipeline.pre_compile(state)
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
pipeline.show(state, best)
|
pipeline.show(state, best)
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ def example_conf1():
|
|||||||
outputs=1,
|
outputs=1,
|
||||||
),
|
),
|
||||||
gene=NormalGeneConfig(
|
gene=NormalGeneConfig(
|
||||||
activation_default='sigmoid',
|
activation_default=Act.sigmoid,
|
||||||
activation_options=('sigmoid',),
|
activation_options=(Act.sigmoid,),
|
||||||
),
|
),
|
||||||
problem=GymNaxConfig(
|
problem=GymNaxConfig(
|
||||||
env_name='CartPole-v1',
|
env_name='CartPole-v1',
|
||||||
@@ -41,8 +41,8 @@ def example_conf2():
|
|||||||
outputs=1,
|
outputs=1,
|
||||||
),
|
),
|
||||||
gene=NormalGeneConfig(
|
gene=NormalGeneConfig(
|
||||||
activation_default='tanh',
|
activation_default=Act.tanh,
|
||||||
activation_options=('tanh',),
|
activation_options=(Act.tanh,),
|
||||||
),
|
),
|
||||||
problem=GymNaxConfig(
|
problem=GymNaxConfig(
|
||||||
env_name='CartPole-v1',
|
env_name='CartPole-v1',
|
||||||
@@ -63,8 +63,8 @@ def example_conf3():
|
|||||||
outputs=2,
|
outputs=2,
|
||||||
),
|
),
|
||||||
gene=NormalGeneConfig(
|
gene=NormalGeneConfig(
|
||||||
activation_default='tanh',
|
activation_default=Act.tanh,
|
||||||
activation_options=('tanh',),
|
activation_options=(Act.tanh,),
|
||||||
),
|
),
|
||||||
problem=GymNaxConfig(
|
problem=GymNaxConfig(
|
||||||
env_name='CartPole-v1',
|
env_name='CartPole-v1',
|
||||||
@@ -80,5 +80,5 @@ if __name__ == '__main__':
|
|||||||
algorithm = NEAT(conf, NormalGene)
|
algorithm = NEAT(conf, NormalGene)
|
||||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
|
pipeline.pre_compile(state)
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
pipeline.show(state, best)
|
|
||||||
|
|||||||
@@ -1,35 +1,4 @@
|
|||||||
from .activation import Activation, act
|
from .activation import Act, act
|
||||||
from .aggregation import Aggregation, agg
|
from .aggregation import Agg, agg
|
||||||
from .tools import *
|
from .tools import *
|
||||||
from .graph import *
|
from .graph import *
|
||||||
|
|
||||||
Activation.name2func = {
|
|
||||||
'sigmoid': Activation.sigmoid_act,
|
|
||||||
'tanh': Activation.tanh_act,
|
|
||||||
'sin': Activation.sin_act,
|
|
||||||
'gauss': Activation.gauss_act,
|
|
||||||
'relu': Activation.relu_act,
|
|
||||||
'elu': Activation.elu_act,
|
|
||||||
'lelu': Activation.lelu_act,
|
|
||||||
'selu': Activation.selu_act,
|
|
||||||
'softplus': Activation.softplus_act,
|
|
||||||
'identity': Activation.identity_act,
|
|
||||||
'clamped': Activation.clamped_act,
|
|
||||||
'inv': Activation.inv_act,
|
|
||||||
'log': Activation.log_act,
|
|
||||||
'exp': Activation.exp_act,
|
|
||||||
'abs': Activation.abs_act,
|
|
||||||
'hat': Activation.hat_act,
|
|
||||||
'square': Activation.square_act,
|
|
||||||
'cube': Activation.cube_act,
|
|
||||||
}
|
|
||||||
|
|
||||||
Aggregation.name2func = {
|
|
||||||
'sum': Aggregation.sum_agg,
|
|
||||||
'product': Aggregation.product_agg,
|
|
||||||
'max': Aggregation.max_agg,
|
|
||||||
'min': Aggregation.min_agg,
|
|
||||||
'maxabs': Aggregation.maxabs_agg,
|
|
||||||
'median': Aggregation.median_agg,
|
|
||||||
'mean': Aggregation.mean_agg,
|
|
||||||
}
|
|
||||||
@@ -2,90 +2,89 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
class Activation:
|
class Act:
|
||||||
name2func = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sigmoid_act(z):
|
def sigmoid(z):
|
||||||
z = jnp.clip(z * 5, -60, 60)
|
z = jnp.clip(z * 5, -60, 60)
|
||||||
return 1 / (1 + jnp.exp(-z))
|
return 1 / (1 + jnp.exp(-z))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tanh_act(z):
|
def tanh(z):
|
||||||
z = jnp.clip(z * 2.5, -60, 60)
|
z = jnp.clip(z * 2.5, -60, 60)
|
||||||
return jnp.tanh(z)
|
return jnp.tanh(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sin_act(z):
|
def sin(z):
|
||||||
z = jnp.clip(z * 5, -60, 60)
|
z = jnp.clip(z * 5, -60, 60)
|
||||||
return jnp.sin(z)
|
return jnp.sin(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gauss_act(z):
|
def gauss(z):
|
||||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||||
return jnp.exp(-z ** 2)
|
return jnp.exp(-z ** 2)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def relu_act(z):
|
def relu(z):
|
||||||
return jnp.maximum(z, 0)
|
return jnp.maximum(z, 0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def elu_act(z):
|
def elu(z):
|
||||||
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def lelu_act(z):
|
def lelu(z):
|
||||||
leaky = 0.005
|
leaky = 0.005
|
||||||
return jnp.where(z > 0, z, leaky * z)
|
return jnp.where(z > 0, z, leaky * z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def selu_act(z):
|
def selu(z):
|
||||||
lam = 1.0507009873554804934193349852946
|
lam = 1.0507009873554804934193349852946
|
||||||
alpha = 1.6732632423543772848170429916717
|
alpha = 1.6732632423543772848170429916717
|
||||||
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def softplus_act(z):
|
def softplus(z):
|
||||||
z = jnp.clip(z * 5, -60, 60)
|
z = jnp.clip(z * 5, -60, 60)
|
||||||
return 0.2 * jnp.log(1 + jnp.exp(z))
|
return 0.2 * jnp.log(1 + jnp.exp(z))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def identity_act(z):
|
def identity(z):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clamped_act(z):
|
def clamped(z):
|
||||||
return jnp.clip(z, -1, 1)
|
return jnp.clip(z, -1, 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def inv_act(z):
|
def inv(z):
|
||||||
z = jnp.maximum(z, 1e-7)
|
z = jnp.maximum(z, 1e-7)
|
||||||
return 1 / z
|
return 1 / z
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def log_act(z):
|
def log(z):
|
||||||
z = jnp.maximum(z, 1e-7)
|
z = jnp.maximum(z, 1e-7)
|
||||||
return jnp.log(z)
|
return jnp.log(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def exp_act(z):
|
def exp(z):
|
||||||
z = jnp.clip(z, -60, 60)
|
z = jnp.clip(z, -60, 60)
|
||||||
return jnp.exp(z)
|
return jnp.exp(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def abs_act(z):
|
def abs(z):
|
||||||
return jnp.abs(z)
|
return jnp.abs(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hat_act(z):
|
def hat(z):
|
||||||
return jnp.maximum(0, 1 - jnp.abs(z))
|
return jnp.maximum(0, 1 - jnp.abs(z))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def square_act(z):
|
def square(z):
|
||||||
return z ** 2
|
return z ** 2
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cube_act(z):
|
def cube(z):
|
||||||
return z ** 3
|
return z ** 3
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,38 +2,37 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
class Aggregation:
|
class Agg:
|
||||||
name2func = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sum_agg(z):
|
def sum(z):
|
||||||
z = jnp.where(jnp.isnan(z), 0, z)
|
z = jnp.where(jnp.isnan(z), 0, z)
|
||||||
return jnp.sum(z, axis=0)
|
return jnp.sum(z, axis=0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def product_agg(z):
|
def product(z):
|
||||||
z = jnp.where(jnp.isnan(z), 1, z)
|
z = jnp.where(jnp.isnan(z), 1, z)
|
||||||
return jnp.prod(z, axis=0)
|
return jnp.prod(z, axis=0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def max_agg(z):
|
def max(z):
|
||||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||||
return jnp.max(z, axis=0)
|
return jnp.max(z, axis=0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def min_agg(z):
|
def min(z):
|
||||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||||
return jnp.min(z, axis=0)
|
return jnp.min(z, axis=0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def maxabs_agg(z):
|
def maxabs(z):
|
||||||
z = jnp.where(jnp.isnan(z), 0, z)
|
z = jnp.where(jnp.isnan(z), 0, z)
|
||||||
abs_z = jnp.abs(z)
|
abs_z = jnp.abs(z)
|
||||||
max_abs_index = jnp.argmax(abs_z)
|
max_abs_index = jnp.argmax(abs_z)
|
||||||
return z[max_abs_index]
|
return z[max_abs_index]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def median_agg(z):
|
def median(z):
|
||||||
n = jnp.sum(~jnp.isnan(z), axis=0)
|
n = jnp.sum(~jnp.isnan(z), axis=0)
|
||||||
|
|
||||||
z = jnp.sort(z) # sort
|
z = jnp.sort(z) # sort
|
||||||
@@ -44,7 +43,7 @@ class Aggregation:
|
|||||||
return median
|
return median
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mean_agg(z):
|
def mean(z):
|
||||||
aux = jnp.where(jnp.isnan(z), 0, z)
|
aux = jnp.where(jnp.isnan(z), 0, z)
|
||||||
valid_values_sum = jnp.sum(aux, axis=0)
|
valid_values_sum = jnp.sum(aux, axis=0)
|
||||||
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
|
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user