modify method cal_spawn_numbers
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
||||
from .population import update_species, create_next_generation, speciate
|
||||
from .population import update_species, create_next_generation, speciate, tell
|
||||
|
||||
from .genome.activations import act_name2func
|
||||
from .genome.aggregations import agg_name2func
|
||||
|
||||
@@ -100,4 +100,4 @@ def create_forward_function(config):
|
||||
elif config['forward_way'] == 'common':
|
||||
return jit(common_forward)
|
||||
|
||||
return forward
|
||||
return jit(forward)
|
||||
|
||||
@@ -11,6 +11,28 @@ from jax import jit, vmap, Array, numpy as jnp
|
||||
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
||||
|
||||
|
||||
@jit
|
||||
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
||||
jit_config):
|
||||
|
||||
generation += 1
|
||||
|
||||
k1, k2, randkey = jax.random.split(randkey, 3)
|
||||
|
||||
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
|
||||
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||
center_cons, generation, jit_config)
|
||||
|
||||
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||
elite_mask, generation, jit_config)
|
||||
|
||||
idx2species, center_nodes, center_cons, species_info = speciate(
|
||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
||||
jit_config)
|
||||
|
||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
||||
|
||||
|
||||
@jit
|
||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
@@ -110,7 +132,13 @@ def cal_spawn_numbers(species_info, jit_config):
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
|
||||
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
|
||||
|
||||
# Avoid too much variation of numbers in a species
|
||||
previous_size = species_info[:, 3].astype(jnp.int32)
|
||||
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
|
||||
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||
|
||||
Reference in New Issue
Block a user