All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -17,13 +17,13 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add
self.node_delete = node_delete
def __call__(self, state, key, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(key)
def __call__(self, state, genome, nodes, conns, new_node_key):
k1, k2, randkey = jax.random.split(state.randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
return nodes, conns
return state.update(randkey=randkey), nodes, conns
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):