From 625c261a49403ca87f1951c50893acb16d1ab565 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 25 May 2024 16:13:41 +0800 Subject: [PATCH] make fully stateful in module gene. --- tensorneat/algorithm/neat/gene/base.py | 16 +++++++++++----- tensorneat/algorithm/neat/gene/conn/base.py | 2 +- .../algorithm/neat/gene/conn/default.py | 19 ++++--------------- tensorneat/algorithm/neat/gene/node/base.py | 2 +- .../algorithm/neat/gene/node/default.py | 18 ++++-------------- 5 files changed, 21 insertions(+), 36 deletions(-) diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index 1430171..1110074 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -1,3 +1,6 @@ +from utils import State + + class BaseGene: "Base class for node genes or connection genes." fixed_attrs = [] @@ -6,18 +9,21 @@ class BaseGene: def __init__(self): pass - def new_custom_attrs(self): + def setup(self, state=State()): + return state + + def new_attrs(self, state, key): raise NotImplementedError - def mutate(self, randkey, gene): + def mutate(self, state, key, gene): raise NotImplementedError - def distance(self, gene1, gene2): + def distance(self, state, gene1, gene2): raise NotImplementedError - def forward(self, attrs, inputs): + def forward(self, state, attrs, inputs): raise NotImplementedError @property def length(self): - return len(self.fixed_attrs) + len(self.custom_attrs) \ No newline at end of file + return len(self.fixed_attrs) + len(self.custom_attrs) diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py index b7b3bdc..17a67fc 100644 --- a/tensorneat/algorithm/neat/gene/conn/base.py +++ b/tensorneat/algorithm/neat/gene/conn/base.py @@ -8,5 +8,5 @@ class BaseConnGene(BaseGene): def __init__(self): super().__init__() - def forward(self, attrs, inputs): + def forward(self, state, attrs, inputs): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index cf1b18e..d8834ed 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -24,21 +24,10 @@ class DefaultConnGene(BaseConnGene): self.weight_mutate_rate = weight_mutate_rate self.weight_replace_rate = weight_replace_rate - def new_custom_attrs(self): + def new_attrs(self, state, key): return jnp.array([self.weight_init_mean]) - - def new_random_attrs(self, key): - return jnp.array([mutate_float(key, - self.weight_init_mean, - self.weight_init_mean, - 1.0, - 0, - 0, - 1.0, - ) - ]) - def mutate(self, key, conn): + def mutate(self, state, key, conn): input_index = conn[0] output_index = conn[1] enabled = conn[2] @@ -53,9 +42,9 @@ class DefaultConnGene(BaseConnGene): return jnp.array([input_index, output_index, enabled, weight]) - def distance(self, attrs1, attrs2): + def distance(self, state, attrs1, attrs2): return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight - def forward(self, attrs, inputs): + def forward(self, state, attrs, inputs): weight = attrs[0] return inputs * weight diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index 2ebfd1b..8b81299 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene): def __init__(self): super().__init__() - def forward(self, attrs, inputs, is_output_node=False): + def forward(self, state, attrs, inputs, is_output_node=False): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index c06e7cf..1c46e17 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -56,22 +56,12 @@ class DefaultNodeGene(BaseNodeGene): self.aggregation_indices = jnp.arange(len(aggregation_options)) self.aggregation_replace_rate = aggregation_replace_rate - def new_custom_attrs(self): + def new_attrs(self, state, key): return jnp.array( [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] ) - - def new_random_attrs(self, key): - return jnp.array([ - mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std, - self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate), - mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std, - self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate), - self.activation_default, - self.aggregation_default, - ]) - def mutate(self, key, node): + def mutate(self, state, key, node): k1, k2, k3, k4 = jax.random.split(key, num=4) index = node[0] @@ -87,7 +77,7 @@ class DefaultNodeGene(BaseNodeGene): return jnp.array([index, bias, res, act, agg]) - def distance(self, node1, node2): + def distance(self, state, node1, node2): return ( jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + @@ -95,7 +85,7 @@ class DefaultNodeGene(BaseNodeGene): (node1[4] != node2[4]) ) - def forward(self, attrs, inputs, is_output_node=False): + def forward(self, state, attrs, inputs, is_output_node=False): bias, res, act_idx, agg_idx = attrs z = agg(agg_idx, inputs, self.aggregation_options)