perfect! fix all bugs!
This commit is contained in:
@@ -64,7 +64,7 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
spawn_number = cal_spawn_numbers(species_info, jit_config)
|
||||
|
||||
# jax.debug.print("spawn_number: {}", spawn_number)
|
||||
# crossover info
|
||||
winner, loser, elite_mask = \
|
||||
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
|
||||
@@ -135,13 +135,16 @@ def cal_spawn_numbers(species_info, jit_config):
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
|
||||
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
|
||||
|
||||
# Avoid too much variation of numbers in a species
|
||||
previous_size = species_info[:, 3].astype(jnp.int32)
|
||||
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
|
||||
|
||||
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# spawn_number = target_spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||
@@ -254,7 +257,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(idx2specie))
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
# jax.debug.print("closest_idx: {}", closest_idx)
|
||||
|
||||
i2s = i2s.at[closest_idx].set(species_info[i, 0])
|
||||
cn = cn.at[i].set(pop_nodes[closest_idx])
|
||||
@@ -268,13 +272,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
_, idx2specie, center_nodes, center_cons, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances))
|
||||
|
||||
# jax.debug.print("species_info: \n{}", species_info)
|
||||
# jax.debug.print("idx2specie: \n{}", idx2specie)
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
||||
jax.debug.print("{}, {}", i, i2s)
|
||||
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
|
||||
current_species_existed = ~jnp.isnan(si[i, 0])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
@@ -324,8 +332,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
||||
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
cacheable_mask = jnp.isnan(i2s) | (o2c > o2p_distance)
|
||||
|
||||
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||
# jax.debug.print("{}", o2p_distance)
|
||||
mask = close_enough_mask & cacheable_mask
|
||||
|
||||
# update species info
|
||||
|
||||
@@ -13,16 +13,16 @@ batch_size = 4
|
||||
fitness_threshold = 100000
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 100
|
||||
pop_size = 50
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
compatibility_weight = 0.5
|
||||
conn_add_prob = 0.5
|
||||
conn_add_prob = 0.4
|
||||
conn_add_trials = 1
|
||||
conn_delete_prob = 0
|
||||
conn_delete_prob = 0.4
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0
|
||||
node_delete_prob = 0.2
|
||||
|
||||
[species]
|
||||
compatibility_threshold = 3.0
|
||||
|
||||
Reference in New Issue
Block a user