All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
def new_attrs(self, state, key):
return jnp.array(
def new_attrs(self, state):
return state, jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
)
def mutate(self, state, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4)
def mutate(self, state, node):
k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
index = node[0]
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
return jnp.array([index, bias, res, act, agg])
return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
def distance(self, state, node1, node2):
return (
return state, (
jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) +
(node1[3] != node2[3]) +
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
lambda: act(act_idx, z, self.activation_options)
)
return z
return state, z