change str in config (act, agg) from str to callable

This commit is contained in:
wls2002
2023-08-05 03:03:02 +08:00
parent 0e44b13291
commit af54db3b12
10 changed files with 55 additions and 101 deletions

View File

@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
from config import GeneConfig
from core import Gene, Genome, State
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True)
@@ -23,12 +23,12 @@ class NormalGeneConfig(GeneConfig):
response_mutate_rate: float = 0.7
response_replace_rate: float = 0.1
activation_default: str = 'sigmoid'
activation_options: Tuple = ('sigmoid',)
activation_default: callable = Act.sigmoid
activation_options: Tuple = (Act.sigmoid, )
activation_replace_rate: float = 0.1
aggregation_default: str = 'sum'
aggregation_options: Tuple = ('sum',)
aggregation_default: callable = Agg.sum
aggregation_options: Tuple = (Agg.sum, )
aggregation_replace_rate: float = 0.1
weight_init_mean: float = 0.0
@@ -49,18 +49,8 @@ class NormalGeneConfig(GeneConfig):
assert self.response_replace_rate >= 0.0
assert self.activation_default == self.activation_options[0]
for name in self.activation_options:
assert name in Activation.name2func, f"Activation function: {name} not found"
assert self.aggregation_default == self.aggregation_options[0]
assert self.aggregation_default in Aggregation.name2func, \
f"Aggregation function: {self.aggregation_default} not found"
for name in self.aggregation_options:
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
@@ -68,8 +58,6 @@ class NormalGene(Gene):
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
self.config = config
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def setup(self, state: State = State()):
return state.update(
@@ -170,9 +158,9 @@ class NormalGene(Gene):
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z)
return new_values

View File

@@ -48,9 +48,9 @@ class RecurrentGene(NormalGene):
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)