perfect! fix all bugs!
This commit is contained in:
@@ -64,7 +64,7 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce
|
|||||||
|
|
||||||
# decide the number of members of each species by their fitness
|
# decide the number of members of each species by their fitness
|
||||||
spawn_number = cal_spawn_numbers(species_info, jit_config)
|
spawn_number = cal_spawn_numbers(species_info, jit_config)
|
||||||
|
# jax.debug.print("spawn_number: {}", spawn_number)
|
||||||
# crossover info
|
# crossover info
|
||||||
winner, loser, elite_mask = \
|
winner, loser, elite_mask = \
|
||||||
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
|
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
|
||||||
@@ -135,13 +135,16 @@ def cal_spawn_numbers(species_info, jit_config):
|
|||||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||||
|
|
||||||
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
|
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
|
||||||
|
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
|
||||||
|
|
||||||
# Avoid too much variation of numbers in a species
|
# Avoid too much variation of numbers in a species
|
||||||
previous_size = species_info[:, 3].astype(jnp.int32)
|
previous_size = species_info[:, 3].astype(jnp.int32)
|
||||||
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
|
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
|
||||||
|
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
|
||||||
spawn_number = spawn_number.astype(jnp.int32)
|
spawn_number = spawn_number.astype(jnp.int32)
|
||||||
|
|
||||||
|
# spawn_number = target_spawn_number.astype(jnp.int32)
|
||||||
|
|
||||||
# must control the sum of spawn_number to be equal to pop_size
|
# must control the sum of spawn_number to be equal to pop_size
|
||||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||||
@@ -254,7 +257,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||||
|
|
||||||
# find the closest one
|
# find the closest one
|
||||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(idx2specie))
|
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||||
|
# jax.debug.print("closest_idx: {}", closest_idx)
|
||||||
|
|
||||||
i2s = i2s.at[closest_idx].set(species_info[i, 0])
|
i2s = i2s.at[closest_idx].set(species_info[i, 0])
|
||||||
cn = cn.at[i].set(pop_nodes[closest_idx])
|
cn = cn.at[i].set(pop_nodes[closest_idx])
|
||||||
@@ -268,13 +272,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
_, idx2specie, center_nodes, center_cons, o2c_distances = \
|
_, idx2specie, center_nodes, center_cons, o2c_distances = \
|
||||||
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances))
|
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances))
|
||||||
|
|
||||||
|
# jax.debug.print("species_info: \n{}", species_info)
|
||||||
|
# jax.debug.print("idx2specie: \n{}", idx2specie)
|
||||||
|
|
||||||
# part 2: assign members to each species
|
# part 2: assign members to each species
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
||||||
jax.debug.print("{}, {}", i, i2s)
|
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
|
||||||
|
current_species_existed = ~jnp.isnan(si[i, 0])
|
||||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||||
not_reach_species_upper_bounds = i < species_size
|
not_reach_species_upper_bounds = i < species_size
|
||||||
return not_all_assigned & not_reach_species_upper_bounds
|
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||||
@@ -324,8 +332,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
||||||
|
|
||||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||||
cacheable_mask = jnp.isnan(i2s) | (o2c > o2p_distance)
|
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||||
|
# jax.debug.print("{}", o2p_distance)
|
||||||
mask = close_enough_mask & cacheable_mask
|
mask = close_enough_mask & cacheable_mask
|
||||||
|
|
||||||
# update species info
|
# update species info
|
||||||
|
|||||||
@@ -13,16 +13,16 @@ batch_size = 4
|
|||||||
fitness_threshold = 100000
|
fitness_threshold = 100000
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 100
|
pop_size = 50
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
compatibility_weight = 0.5
|
compatibility_weight = 0.5
|
||||||
conn_add_prob = 0.5
|
conn_add_prob = 0.4
|
||||||
conn_add_trials = 1
|
conn_add_trials = 1
|
||||||
conn_delete_prob = 0
|
conn_delete_prob = 0.4
|
||||||
node_add_prob = 0.2
|
node_add_prob = 0.2
|
||||||
node_delete_prob = 0
|
node_delete_prob = 0.2
|
||||||
|
|
||||||
[species]
|
[species]
|
||||||
compatibility_threshold = 3.0
|
compatibility_threshold = 3.0
|
||||||
|
|||||||
Reference in New Issue
Block a user