fix bug for using state.randkey in mutate of the gene
This commit is contained in:
@@ -74,7 +74,7 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
return jnp.array([bias, res, act, agg])
|
||||
|
||||
def mutate(self, state, randkey, node):
|
||||
k1, k2, k3, k4 = jax.random.split(state.randkey, num=4)
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(
|
||||
|
||||
@@ -62,7 +62,7 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
||||
return jnp.array([bias, act, agg])
|
||||
|
||||
def mutate(self, state, randkey, node):
|
||||
k1, k2, k3, k4 = jax.random.split(state.randkey, num=4)
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(
|
||||
|
||||
@@ -95,7 +95,7 @@ class NormalizedNode(BaseNodeGene):
|
||||
return jnp.array([bias, act, agg, mean, std, alpha, beta])
|
||||
|
||||
def mutate(self, state, randkey, node):
|
||||
k1, k2, k3, k4, k5, k6 = jax.random.split(state.randkey, num=6)
|
||||
k1, k2, k3, k4, k5, k6 = jax.random.split(randkey, num=6)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(
|
||||
|
||||
Reference in New Issue
Block a user