All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -40,8 +40,8 @@ class DefaultSpecies(BaseSpecies):
self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey):
k1, k2 = jax.random.split(randkey, 2)
def setup(self, key, state=State()):
k1, k2 = jax.random.split(key, 2)
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
@@ -65,8 +65,8 @@ class DefaultSpecies(BaseSpecies):
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
return State(
randkey=k2,
return state.register(
species_randkey=k2,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_keys=species_keys,
@@ -134,8 +134,7 @@ class DefaultSpecies(BaseSpecies):
def check_stagnation(idx):
# determine whether the species stagnation
st = (
(species_fitness[idx] <= state.best_fitness[
idx]) & # not better than the best fitness of the species
(species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
)