fix a bug in stagnation
This commit is contained in:
@@ -23,6 +23,7 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
|
||||
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||
center_cons, generation, jit_config)
|
||||
|
||||
|
||||
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||
elite_mask, generation, jit_config)
|
||||
|
||||
@@ -30,6 +31,7 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
|
||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
||||
jit_config)
|
||||
|
||||
|
||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
||||
|
||||
|
||||
@@ -111,7 +113,7 @@ def stagnation(species_fitness, species_info, center_nodes, center_cons, generat
|
||||
species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
|
||||
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, center_nodes)
|
||||
center_cons = jnp.where(spe_st[:, None, None], jnp.nan, center_cons)
|
||||
species_fitness = jnp.where(spe_st, jnp.nan, species_fitness)
|
||||
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
|
||||
|
||||
return species_fitness, species_info, center_nodes, center_cons
|
||||
|
||||
@@ -269,6 +271,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
||||
jax.debug.print("{}, {}", i, i2s)
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
Reference in New Issue
Block a user