diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py index 44740e7..aab6907 100644 --- a/algorithms/neat/population.py +++ b/algorithms/neat/population.py @@ -233,56 +233,60 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener # prepare distance functions o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population - s2p_distance_func = vmap( - o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population - ) # idx to specie key - idx2specie = jnp.full((pop_size,), jnp.nan) # I_INT means not assigned to any species + idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species - # part 1: find new centers - # the distance between each species' center and each genome in population - s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config) + # the distance between genomes to its center genomes + o2c_distances = jnp.full((pop_size, ), jnp.inf) - def find_new_centers(i, carry): - i2s, cn, cc = carry - # find new center - idx = argmin_with_mask(s2p_distance[i], mask=jnp.isnan(i2s)) + # step 1: find new centers + def cond_func(carry): + i, i2s, cn, cc, o2c = carry + species_key = species_info[i, 0] + # jax.debug.print("{}, {}", i, species_key) + return (i < species_size) & (~jnp.isnan(species_key)) # current species is existing - # check species[i] exist or not - # if not exist, set idx and i to I_INT, jax will not do array value assignment - idx = jnp.where(~jnp.isnan(species_info[i, 0]), idx, I_INT) - i = jnp.where(~jnp.isnan(species_info[i, 0]), i, I_INT) + def body_func(carry): + i, i2s, cn, cc, o2c = carry + distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) - i2s = i2s.at[idx].set(species_info[i, 0]) - cn = cn.at[i].set(pop_nodes[idx]) - cc = cc.at[i].set(pop_cons[idx]) - return i2s, cn, cc + # find the closest one + closest_idx = argmin_with_mask(distances, mask=jnp.isnan(idx2specie)) - idx2specie, center_nodes, center_cons = \ - jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons)) + i2s = i2s.at[closest_idx].set(species_info[i, 0]) + cn = cn.at[i].set(pop_nodes[closest_idx]) + cc = cc.at[i].set(pop_cons[closest_idx]) + + # the genome with closest_idx will become the new center, thus its distance to center is 0. + o2c = o2c.at[closest_idx].set(0) + + return i + 1, i2s, cn, cc, o2c + + _, idx2specie, center_nodes, center_cons, o2c_distances = \ + jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances)) # part 2: assign members to each species def cond_func(carry): - i, i2s, cn, cc, si, ck = carry # si is short for species_info, ck is short for current key + i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key not_all_assigned = jnp.any(jnp.isnan(i2s)) not_reach_species_upper_bounds = i < species_size return not_all_assigned & not_reach_species_upper_bounds def body_func(carry): - i, i2s, cn, cc, si, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons + i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons - i2s, scn, scc, si, ck = jax.lax.cond( + _, i2s, scn, scc, si, o2c, ck = jax.lax.cond( jnp.isnan(si[i, 0]), # whether the current species is existing or not - create_new_specie, # if not existing, create a new specie + create_new_species, # if not existing, create a new specie update_exist_specie, # if existing, update the specie - (i, i2s, cn, cc, si, ck) + (i, i2s, cn, cc, si, o2c, ck) ) - return i + 1, i2s, scn, scc, si, ck + return i + 1, i2s, scn, scc, si, o2c, ck - def create_new_specie(carry): - i, i2s, cn, cc, si, ck = carry + def create_new_species(carry): + i, i2s, cn, cc, si, o2c, ck = carry # pick the first one who has not been assigned to any species idx = fetch_first(jnp.isnan(i2s)) @@ -291,43 +295,58 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener # [key, best score, last update generation, members_count] si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0])) i2s = i2s.at[idx].set(ck) + o2c = o2c.at[idx].set(0) # update center genomes cn = cn.at[i].set(pop_nodes[idx]) cc = cc.at[i].set(pop_cons[idx]) - i2s = speciate_by_threshold((i, i2s, cn, cc, si)) - return i2s, cn, cc, si, ck + 1 # change to next new speciate key + i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) + + # when a new species is created, it needs to be updated, thus do not change i + return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key def update_exist_specie(carry): - i, i2s, cn, cc, si, ck = carry - i2s = speciate_by_threshold((i, i2s, cn, cc, si)) - return i2s, cn, cc, si, ck + i, i2s, cn, cc, si, o2c, ck = carry + i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) + + # turn to next species + return i + 1, i2s, cn, cc, si, o2c, ck def speciate_by_threshold(carry): - i, i2s, cn, cc, si = carry + i, i2s, cn, cc, si, o2c = carry # distance between such center genome and ppo genomes o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) close_enough_mask = o2p_distance < jit_config['compatibility_threshold'] - # when it is close enough, assign it to the species, remember not to update genome has already been assigned - i2s = jnp.where(close_enough_mask & jnp.isnan(i2s), si[i, 0], i2s) - return i2s + # when a genome is not assigned or the distance between its current center is bigger than this center + cacheable_mask = jnp.isnan(i2s) | (o2c > o2p_distance) + + mask = close_enough_mask & cacheable_mask + + # update species info + i2s = jnp.where(mask, si[i, 0], i2s) + + # update distance between centers + o2c = jnp.where(mask, o2p_distance, o2c) + + return i2s, o2c species_keys = species_info[:, 0] current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1 + # update idx2specie - _, idx2specie, center_nodes, center_cons, species_info, _ = jax.lax.while_loop( + _, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop( cond_func, body_func, - (0, idx2specie, center_nodes, center_cons, species_info, current_new_key) + (0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key) ) # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition seems to be only happened when the number of species is reached species upper bounds - idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie) + # this condition can only happen when the number of species is reached species upper bounds + idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie) # update members count def count_members(idx): diff --git a/configs/default_config.ini b/configs/default_config.ini index 0860532..4d56a33 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -3,7 +3,7 @@ num_inputs = 2 num_outputs = 1 init_maximum_nodes = 50 init_maximum_connections = 50 -init_maximum_species = 10 +init_maximum_species = 100 expand_coe = 1.5 pre_expand_threshold = 0.75 forward_way = "pop" @@ -13,7 +13,7 @@ batch_size = 4 fitness_threshold = 100000 generation_limit = 1000 fitness_criterion = "max" -pop_size = 2000 +pop_size = 100 [genome] compatibility_disjoint = 1.0