update speciate function
This commit is contained in:
@@ -233,56 +233,60 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
|
|
||||||
# prepare distance functions
|
# prepare distance functions
|
||||||
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
|
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
|
||||||
s2p_distance_func = vmap(
|
|
||||||
o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population
|
|
||||||
)
|
|
||||||
|
|
||||||
# idx to specie key
|
# idx to specie key
|
||||||
idx2specie = jnp.full((pop_size,), jnp.nan) # I_INT means not assigned to any species
|
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||||
|
|
||||||
# part 1: find new centers
|
# the distance between genomes to its center genomes
|
||||||
# the distance between each species' center and each genome in population
|
o2c_distances = jnp.full((pop_size, ), jnp.inf)
|
||||||
s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)
|
|
||||||
|
|
||||||
def find_new_centers(i, carry):
|
# step 1: find new centers
|
||||||
i2s, cn, cc = carry
|
def cond_func(carry):
|
||||||
# find new center
|
i, i2s, cn, cc, o2c = carry
|
||||||
idx = argmin_with_mask(s2p_distance[i], mask=jnp.isnan(i2s))
|
species_key = species_info[i, 0]
|
||||||
|
# jax.debug.print("{}, {}", i, species_key)
|
||||||
|
return (i < species_size) & (~jnp.isnan(species_key)) # current species is existing
|
||||||
|
|
||||||
# check species[i] exist or not
|
def body_func(carry):
|
||||||
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
i, i2s, cn, cc, o2c = carry
|
||||||
idx = jnp.where(~jnp.isnan(species_info[i, 0]), idx, I_INT)
|
distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||||
i = jnp.where(~jnp.isnan(species_info[i, 0]), i, I_INT)
|
|
||||||
|
|
||||||
i2s = i2s.at[idx].set(species_info[i, 0])
|
# find the closest one
|
||||||
cn = cn.at[i].set(pop_nodes[idx])
|
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(idx2specie))
|
||||||
cc = cc.at[i].set(pop_cons[idx])
|
|
||||||
return i2s, cn, cc
|
|
||||||
|
|
||||||
idx2specie, center_nodes, center_cons = \
|
i2s = i2s.at[closest_idx].set(species_info[i, 0])
|
||||||
jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))
|
cn = cn.at[i].set(pop_nodes[closest_idx])
|
||||||
|
cc = cc.at[i].set(pop_cons[closest_idx])
|
||||||
|
|
||||||
|
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||||
|
o2c = o2c.at[closest_idx].set(0)
|
||||||
|
|
||||||
|
return i + 1, i2s, cn, cc, o2c
|
||||||
|
|
||||||
|
_, idx2specie, center_nodes, center_cons, o2c_distances = \
|
||||||
|
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances))
|
||||||
|
|
||||||
# part 2: assign members to each species
|
# part 2: assign members to each species
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
i, i2s, cn, cc, si, ck = carry # si is short for species_info, ck is short for current key
|
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
||||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||||
not_reach_species_upper_bounds = i < species_size
|
not_reach_species_upper_bounds = i < species_size
|
||||||
return not_all_assigned & not_reach_species_upper_bounds
|
return not_all_assigned & not_reach_species_upper_bounds
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
i, i2s, cn, cc, si, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||||
|
|
||||||
i2s, scn, scc, si, ck = jax.lax.cond(
|
_, i2s, scn, scc, si, o2c, ck = jax.lax.cond(
|
||||||
jnp.isnan(si[i, 0]), # whether the current species is existing or not
|
jnp.isnan(si[i, 0]), # whether the current species is existing or not
|
||||||
create_new_specie, # if not existing, create a new specie
|
create_new_species, # if not existing, create a new specie
|
||||||
update_exist_specie, # if existing, update the specie
|
update_exist_specie, # if existing, update the specie
|
||||||
(i, i2s, cn, cc, si, ck)
|
(i, i2s, cn, cc, si, o2c, ck)
|
||||||
)
|
)
|
||||||
|
|
||||||
return i + 1, i2s, scn, scc, si, ck
|
return i + 1, i2s, scn, scc, si, o2c, ck
|
||||||
|
|
||||||
def create_new_specie(carry):
|
def create_new_species(carry):
|
||||||
i, i2s, cn, cc, si, ck = carry
|
i, i2s, cn, cc, si, o2c, ck = carry
|
||||||
|
|
||||||
# pick the first one who has not been assigned to any species
|
# pick the first one who has not been assigned to any species
|
||||||
idx = fetch_first(jnp.isnan(i2s))
|
idx = fetch_first(jnp.isnan(i2s))
|
||||||
@@ -291,43 +295,58 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
# [key, best score, last update generation, members_count]
|
# [key, best score, last update generation, members_count]
|
||||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
|
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
|
||||||
i2s = i2s.at[idx].set(ck)
|
i2s = i2s.at[idx].set(ck)
|
||||||
|
o2c = o2c.at[idx].set(0)
|
||||||
|
|
||||||
# update center genomes
|
# update center genomes
|
||||||
cn = cn.at[i].set(pop_nodes[idx])
|
cn = cn.at[i].set(pop_nodes[idx])
|
||||||
cc = cc.at[i].set(pop_cons[idx])
|
cc = cc.at[i].set(pop_cons[idx])
|
||||||
|
|
||||||
i2s = speciate_by_threshold((i, i2s, cn, cc, si))
|
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||||
return i2s, cn, cc, si, ck + 1 # change to next new speciate key
|
|
||||||
|
# when a new species is created, it needs to be updated, thus do not change i
|
||||||
|
return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key
|
||||||
|
|
||||||
def update_exist_specie(carry):
|
def update_exist_specie(carry):
|
||||||
i, i2s, cn, cc, si, ck = carry
|
i, i2s, cn, cc, si, o2c, ck = carry
|
||||||
i2s = speciate_by_threshold((i, i2s, cn, cc, si))
|
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||||
return i2s, cn, cc, si, ck
|
|
||||||
|
# turn to next species
|
||||||
|
return i + 1, i2s, cn, cc, si, o2c, ck
|
||||||
|
|
||||||
def speciate_by_threshold(carry):
|
def speciate_by_threshold(carry):
|
||||||
i, i2s, cn, cc, si = carry
|
i, i2s, cn, cc, si, o2c = carry
|
||||||
|
|
||||||
# distance between such center genome and ppo genomes
|
# distance between such center genome and ppo genomes
|
||||||
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||||
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
||||||
|
|
||||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||||
i2s = jnp.where(close_enough_mask & jnp.isnan(i2s), si[i, 0], i2s)
|
cacheable_mask = jnp.isnan(i2s) | (o2c > o2p_distance)
|
||||||
return i2s
|
|
||||||
|
mask = close_enough_mask & cacheable_mask
|
||||||
|
|
||||||
|
# update species info
|
||||||
|
i2s = jnp.where(mask, si[i, 0], i2s)
|
||||||
|
|
||||||
|
# update distance between centers
|
||||||
|
o2c = jnp.where(mask, o2p_distance, o2c)
|
||||||
|
|
||||||
|
return i2s, o2c
|
||||||
|
|
||||||
species_keys = species_info[:, 0]
|
species_keys = species_info[:, 0]
|
||||||
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
|
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
|
||||||
|
|
||||||
|
|
||||||
# update idx2specie
|
# update idx2specie
|
||||||
_, idx2specie, center_nodes, center_cons, species_info, _ = jax.lax.while_loop(
|
_, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop(
|
||||||
cond_func,
|
cond_func,
|
||||||
body_func,
|
body_func,
|
||||||
(0, idx2specie, center_nodes, center_cons, species_info, current_new_key)
|
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
# this condition can only happen when the number of species is reached species upper bounds
|
||||||
idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie)
|
idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie)
|
||||||
|
|
||||||
# update members count
|
# update members count
|
||||||
def count_members(idx):
|
def count_members(idx):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ num_inputs = 2
|
|||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
init_maximum_nodes = 50
|
init_maximum_nodes = 50
|
||||||
init_maximum_connections = 50
|
init_maximum_connections = 50
|
||||||
init_maximum_species = 10
|
init_maximum_species = 100
|
||||||
expand_coe = 1.5
|
expand_coe = 1.5
|
||||||
pre_expand_threshold = 0.75
|
pre_expand_threshold = 0.75
|
||||||
forward_way = "pop"
|
forward_way = "pop"
|
||||||
@@ -13,7 +13,7 @@ batch_size = 4
|
|||||||
fitness_threshold = 100000
|
fitness_threshold = 100000
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 2000
|
pop_size = 100
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
|
|||||||
Reference in New Issue
Block a user