make fully stateful in module gene.

This commit is contained in:
wls2002
2024-05-25 16:13:41 +08:00
parent 3b2f917aee
commit 625c261a49
5 changed files with 21 additions and 36 deletions

View File

@@ -56,22 +56,12 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
def new_custom_attrs(self):
def new_attrs(self, state, key):
return jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
)
def new_random_attrs(self, key):
return jnp.array([
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
self.activation_default,
self.aggregation_default,
])
def mutate(self, key, node):
def mutate(self, state, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4)
index = node[0]
@@ -87,7 +77,7 @@ class DefaultNodeGene(BaseNodeGene):
return jnp.array([index, bias, res, act, agg])
def distance(self, node1, node2):
def distance(self, state, node1, node2):
return (
jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) +
@@ -95,7 +85,7 @@ class DefaultNodeGene(BaseNodeGene):
(node1[4] != node2[4])
)
def forward(self, attrs, inputs, is_output_node=False):
def forward(self, state, attrs, inputs, is_output_node=False):
bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)