fix a bug in stagnation

This commit is contained in:
wls2002
2023-07-01 16:55:45 +08:00
parent 2a6e958408
commit eb15ff72fe
4 changed files with 122 additions and 3 deletions

View File

@@ -23,6 +23,7 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
update_species(k1, fitness, species_info, idx2species, center_nodes,
center_cons, generation, jit_config)
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, generation, jit_config)
@@ -30,6 +31,7 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
jit_config)
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
@@ -111,7 +113,7 @@ def stagnation(species_fitness, species_info, center_nodes, center_cons, generat
species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, center_nodes)
center_cons = jnp.where(spe_st[:, None, None], jnp.nan, center_cons)
species_fitness = jnp.where(spe_st, jnp.nan, species_fitness)
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
return species_fitness, species_info, center_nodes, center_cons
@@ -269,6 +271,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
jax.debug.print("{}, {}", i, i2s)
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds