All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array(
|
||||
def new_attrs(self, state):
|
||||
return state, jnp.array(
|
||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||
)
|
||||
|
||||
def mutate(self, state, key, node):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
def mutate(self, state, node):
|
||||
k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
|
||||
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
|
||||
|
||||
return jnp.array([index, bias, res, act, agg])
|
||||
return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
|
||||
|
||||
def distance(self, state, node1, node2):
|
||||
return (
|
||||
return state, (
|
||||
jnp.abs(node1[1] - node2[1]) +
|
||||
jnp.abs(node1[2] - node2[2]) +
|
||||
(node1[3] != node2[3]) +
|
||||
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
return state, z
|
||||
|
||||
Reference in New Issue
Block a user