modify method cal_spawn_numbers

spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
This commit is contained in:
wls2002
2023-07-01 13:36:19 +08:00
parent 896082900a
commit f6dcb97df8
7 changed files with 64 additions and 21 deletions

View File

@@ -2,7 +2,7 @@
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate
from .population import update_species, create_next_generation, speciate, tell
from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func

View File

@@ -100,4 +100,4 @@ def create_forward_function(config):
elif config['forward_way'] == 'common':
return jit(common_forward)
return forward
return jit(forward)

View File

@@ -11,6 +11,28 @@ from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
@jit
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
jit_config):
generation += 1
k1, k2, randkey = jax.random.split(randkey, 3)
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
update_species(k1, fitness, species_info, idx2species, center_nodes,
center_cons, generation, jit_config)
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, generation, jit_config)
idx2species, center_nodes, center_cons, species_info = speciate(
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
jit_config)
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
"""
@@ -110,7 +132,13 @@ def cal_spawn_numbers(species_info, jit_config):
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
# Avoid too much variation of numbers in a species
previous_size = species_info[:, 3].astype(jnp.int32)
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = jit_config['pop_size'] - jnp.sum(spawn_number)