change str in config (act, agg) from str to callable

This commit is contained in:
wls2002
2023-08-05 03:03:02 +08:00
parent 0e44b13291
commit af54db3b12
10 changed files with 55 additions and 101 deletions

View File

@@ -6,7 +6,7 @@ import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome, Gene
from utils import Activation, Aggregation
from utils import Act, Agg
from .substrate import analysis_substrate
from algorithm import NEAT
@@ -90,10 +90,7 @@ class HyperNEATGene:
@staticmethod
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
act = Activation.name2func[config.activation]
agg = Aggregation.name2func[config.aggregation]
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation)
nodes, weights = transformed

View File

@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
from config import GeneConfig
from core import Gene, Genome, State
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True)
@@ -23,12 +23,12 @@ class NormalGeneConfig(GeneConfig):
response_mutate_rate: float = 0.7
response_replace_rate: float = 0.1
activation_default: str = 'sigmoid'
activation_options: Tuple = ('sigmoid',)
activation_default: callable = Act.sigmoid
activation_options: Tuple = (Act.sigmoid, )
activation_replace_rate: float = 0.1
aggregation_default: str = 'sum'
aggregation_options: Tuple = ('sum',)
aggregation_default: callable = Agg.sum
aggregation_options: Tuple = (Agg.sum, )
aggregation_replace_rate: float = 0.1
weight_init_mean: float = 0.0
@@ -49,18 +49,8 @@ class NormalGeneConfig(GeneConfig):
assert self.response_replace_rate >= 0.0
assert self.activation_default == self.activation_options[0]
for name in self.activation_options:
assert name in Activation.name2func, f"Activation function: {name} not found"
assert self.aggregation_default == self.aggregation_options[0]
assert self.aggregation_default in Aggregation.name2func, \
f"Aggregation function: {self.aggregation_default} not found"
for name in self.aggregation_options:
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
@@ -68,8 +58,6 @@ class NormalGene(Gene):
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
self.config = config
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def setup(self, state: State = State()):
return state.update(
@@ -170,9 +158,9 @@ class NormalGene(Gene):
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z)
return new_values

View File

@@ -48,9 +48,9 @@ class RecurrentGene(NormalGene):
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)