perfect! fix all bugs!

This commit is contained in:
wls2002
2023-07-01 17:46:01 +08:00
parent eb15ff72fe
commit e711146f41
2 changed files with 19 additions and 11 deletions

View File

@@ -64,7 +64,7 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce
# decide the number of members of each species by their fitness
spawn_number = cal_spawn_numbers(species_info, jit_config)
# jax.debug.print("spawn_number: {}", spawn_number)
# crossover info
winner, loser, elite_mask = \
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
@@ -135,13 +135,16 @@ def cal_spawn_numbers(species_info, jit_config):
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
# Avoid too much variation of numbers in a species
previous_size = species_info[:, 3].astype(jnp.int32)
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
spawn_number = spawn_number.astype(jnp.int32)
# spawn_number = target_spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = jit_config['pop_size'] - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
@@ -254,7 +257,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(idx2specie))
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
# jax.debug.print("closest_idx: {}", closest_idx)
i2s = i2s.at[closest_idx].set(species_info[i, 0])
cn = cn.at[i].set(pop_nodes[closest_idx])
@@ -268,13 +272,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
_, idx2specie, center_nodes, center_cons, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances))
# jax.debug.print("species_info: \n{}", species_info)
# jax.debug.print("idx2specie: \n{}", idx2specie)
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
jax.debug.print("{}, {}", i, i2s)
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
current_species_existed = ~jnp.isnan(si[i, 0])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
def body_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
@@ -324,8 +332,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
# when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2c > o2p_distance)
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# update species info

View File

@@ -13,16 +13,16 @@ batch_size = 4
fitness_threshold = 100000
generation_limit = 1000
fitness_criterion = "max"
pop_size = 100
pop_size = 50
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0
conn_delete_prob = 0.4
node_add_prob = 0.2
node_delete_prob = 0
node_delete_prob = 0.2
[species]
compatibility_threshold = 3.0