All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -40,8 +40,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, randkey):
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
def setup(self, key, state=State()):
|
||||
k1, k2 = jax.random.split(key, 2)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
|
||||
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||
@@ -65,8 +65,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return State(
|
||||
randkey=k2,
|
||||
return state.register(
|
||||
species_randkey=k2,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -134,8 +134,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
def check_stagnation(idx):
|
||||
# determine whether the species stagnation
|
||||
st = (
|
||||
(species_fitness[idx] <= state.best_fitness[
|
||||
idx]) & # not better than the best fitness of the species
|
||||
(species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species
|
||||
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user