From bc8267bad09d2fd657d07ec0236198281e4ee7bb Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 31 May 2024 16:18:02 +0800 Subject: [PATCH] fix bug for using state.randkey in mutate of the gene --- tensorneat/algorithm/neat/gene/node/default.py | 2 +- tensorneat/algorithm/neat/gene/node/default_without_response.py | 2 +- tensorneat/algorithm/neat/gene/node/normalized.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 8a884b7..bb90780 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -74,7 +74,7 @@ class DefaultNodeGene(BaseNodeGene): return jnp.array([bias, res, act, agg]) def mutate(self, state, randkey, node): - k1, k2, k3, k4 = jax.random.split(state.randkey, num=4) + k1, k2, k3, k4 = jax.random.split(randkey, num=4) index = node[0] bias = mutate_float( diff --git a/tensorneat/algorithm/neat/gene/node/default_without_response.py b/tensorneat/algorithm/neat/gene/node/default_without_response.py index 9f4fd5b..e29ae4f 100644 --- a/tensorneat/algorithm/neat/gene/node/default_without_response.py +++ b/tensorneat/algorithm/neat/gene/node/default_without_response.py @@ -62,7 +62,7 @@ class NodeGeneWithoutResponse(BaseNodeGene): return jnp.array([bias, act, agg]) def mutate(self, state, randkey, node): - k1, k2, k3, k4 = jax.random.split(state.randkey, num=4) + k1, k2, k3, k4 = jax.random.split(randkey, num=4) index = node[0] bias = mutate_float( diff --git a/tensorneat/algorithm/neat/gene/node/normalized.py b/tensorneat/algorithm/neat/gene/node/normalized.py index 62a48df..7dfb045 100644 --- a/tensorneat/algorithm/neat/gene/node/normalized.py +++ b/tensorneat/algorithm/neat/gene/node/normalized.py @@ -95,7 +95,7 @@ class NormalizedNode(BaseNodeGene): return jnp.array([bias, act, agg, mean, std, alpha, beta]) def mutate(self, state, randkey, node): - k1, k2, k3, k4, k5, k6 = jax.random.split(state.randkey, num=6) + k1, k2, k3, k4, k5, k6 = jax.random.split(randkey, num=6) index = node[0] bias = mutate_float(