diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py index 90ccbac..0119a28 100644 --- a/algorithms/neat/population.py +++ b/algorithms/neat/population.py @@ -64,7 +64,7 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce # decide the number of members of each species by their fitness spawn_number = cal_spawn_numbers(species_info, jit_config) - + # jax.debug.print("spawn_number: {}", spawn_number) # crossover info winner, loser, elite_mask = \ create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config) @@ -135,13 +135,16 @@ def cal_spawn_numbers(species_info, jit_config): spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member + # jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number) # Avoid too much variation of numbers in a species previous_size = species_info[:, 3].astype(jnp.int32) spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate'] - + # jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number) spawn_number = spawn_number.astype(jnp.int32) + # spawn_number = target_spawn_number.astype(jnp.int32) + # must control the sum of spawn_number to be equal to pop_size error = jit_config['pop_size'] - jnp.sum(spawn_number) spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number @@ -254,7 +257,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) # find the closest one - closest_idx = argmin_with_mask(distances, mask=jnp.isnan(idx2specie)) + closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) + # jax.debug.print("closest_idx: {}", closest_idx) i2s = i2s.at[closest_idx].set(species_info[i, 0]) cn = cn.at[i].set(pop_nodes[closest_idx]) @@ -268,13 +272,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener _, idx2specie, center_nodes, center_cons, o2c_distances = \ jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances)) + # jax.debug.print("species_info: \n{}", species_info) + # jax.debug.print("idx2specie: \n{}", idx2specie) + # part 2: assign members to each species def cond_func(carry): i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key - jax.debug.print("{}, {}", i, i2s) + # jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si) + current_species_existed = ~jnp.isnan(si[i, 0]) not_all_assigned = jnp.any(jnp.isnan(i2s)) not_reach_species_upper_bounds = i < species_size - return not_all_assigned & not_reach_species_upper_bounds + return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds) def body_func(carry): i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons @@ -324,8 +332,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener close_enough_mask = o2p_distance < jit_config['compatibility_threshold'] # when a genome is not assigned or the distance between its current center is bigger than this center - cacheable_mask = jnp.isnan(i2s) | (o2c > o2p_distance) - + cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) + # jax.debug.print("{}", o2p_distance) mask = close_enough_mask & cacheable_mask # update species info diff --git a/configs/default_config.ini b/configs/default_config.ini index 67bdb92..3533f65 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -13,16 +13,16 @@ batch_size = 4 fitness_threshold = 100000 generation_limit = 1000 fitness_criterion = "max" -pop_size = 100 +pop_size = 50 [genome] compatibility_disjoint = 1.0 compatibility_weight = 0.5 -conn_add_prob = 0.5 +conn_add_prob = 0.4 conn_add_trials = 1 -conn_delete_prob = 0 +conn_delete_prob = 0.4 node_add_prob = 0.2 -node_delete_prob = 0 +node_delete_prob = 0.2 [species] compatibility_threshold = 3.0