change str in config (act, agg) from str to callable
This commit is contained in:
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from config import Config, HyperNeatConfig
|
||||
from core import Algorithm, Substrate, State, Genome, Gene
|
||||
from utils import Activation, Aggregation
|
||||
from utils import Act, Agg
|
||||
from .substrate import analysis_substrate
|
||||
from algorithm import NEAT
|
||||
|
||||
@@ -90,10 +90,7 @@ class HyperNEATGene:
|
||||
|
||||
@staticmethod
|
||||
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
|
||||
act = Activation.name2func[config.activation]
|
||||
agg = Aggregation.name2func[config.aggregation]
|
||||
|
||||
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
|
||||
batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation)
|
||||
|
||||
nodes, weights = transformed
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
|
||||
|
||||
from config import GeneConfig
|
||||
from core import Gene, Genome, State
|
||||
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
|
||||
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -23,12 +23,12 @@ class NormalGeneConfig(GeneConfig):
|
||||
response_mutate_rate: float = 0.7
|
||||
response_replace_rate: float = 0.1
|
||||
|
||||
activation_default: str = 'sigmoid'
|
||||
activation_options: Tuple = ('sigmoid',)
|
||||
activation_default: callable = Act.sigmoid
|
||||
activation_options: Tuple = (Act.sigmoid, )
|
||||
activation_replace_rate: float = 0.1
|
||||
|
||||
aggregation_default: str = 'sum'
|
||||
aggregation_options: Tuple = ('sum',)
|
||||
aggregation_default: callable = Agg.sum
|
||||
aggregation_options: Tuple = (Agg.sum, )
|
||||
aggregation_replace_rate: float = 0.1
|
||||
|
||||
weight_init_mean: float = 0.0
|
||||
@@ -49,18 +49,8 @@ class NormalGeneConfig(GeneConfig):
|
||||
assert self.response_replace_rate >= 0.0
|
||||
|
||||
assert self.activation_default == self.activation_options[0]
|
||||
|
||||
for name in self.activation_options:
|
||||
assert name in Activation.name2func, f"Activation function: {name} not found"
|
||||
|
||||
assert self.aggregation_default == self.aggregation_options[0]
|
||||
|
||||
assert self.aggregation_default in Aggregation.name2func, \
|
||||
f"Aggregation function: {self.aggregation_default} not found"
|
||||
|
||||
for name in self.aggregation_options:
|
||||
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
|
||||
|
||||
|
||||
class NormalGene(Gene):
|
||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
@@ -68,8 +58,6 @@ class NormalGene(Gene):
|
||||
|
||||
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
|
||||
self.config = config
|
||||
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||
|
||||
def setup(self, state: State = State()):
|
||||
return state.update(
|
||||
@@ -170,9 +158,9 @@ class NormalGene(Gene):
|
||||
|
||||
def hit():
|
||||
ins = values * weights[:, i]
|
||||
z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
|
||||
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
|
||||
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
||||
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
@@ -48,9 +48,9 @@ class RecurrentGene(NormalGene):
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
|
||||
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
|
||||
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
|
||||
|
||||
Reference in New Issue
Block a user