add normalize in _cal_spawn_numbers_by_fitness

This commit is contained in:
wls2002
2024-11-22 12:17:49 +08:00
parent 647e750b38
commit 5fdaf6806c

View File

@@ -271,6 +271,12 @@ class SpeciesController(StatefulBaseClass):
previous_size previous_size
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate + (target_spawn_number - previous_size) * self.spawn_number_change_rate
) )
# maintain min_species_size, this will not influence nan
spawn_number = jnp.where(
spawn_number < self.min_species_size, self.min_species_size, spawn_number
)
# convert to int, this will also make nan to 0
spawn_number = spawn_number.astype(jnp.int32) spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size # must control the sum of spawn_number to be equal to pop_size
@@ -287,11 +293,14 @@ class SpeciesController(StatefulBaseClass):
the species with higher fitness will have more members the species with higher fitness will have more members
""" """
species_keys = species_state.species_keys
# the fitness of each species # the fitness of each species
species_fitness = species_state.best_fitness species_fitness = species_state.best_fitness
# normalize the fitness before calculating the spawn number
# consider that the fitness may be negative
# in this way the species with the lowest fitness will have spawn_number = 0
species_fitness = species_fitness - jnp.min(species_fitness)
# calculate the spawn number rate by the fitness of each species # calculate the spawn number rate by the fitness of each species
spawn_number_rate = species_fitness / jnp.sum( spawn_number_rate = species_fitness / jnp.sum(
species_fitness, where=~jnp.isnan(species_fitness) species_fitness, where=~jnp.isnan(species_fitness)
@@ -306,6 +315,12 @@ class SpeciesController(StatefulBaseClass):
previous_size previous_size
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate + (target_spawn_number - previous_size) * self.spawn_number_change_rate
) )
# maintain min_species_size, this will not influence nan
spawn_number = jnp.where(
spawn_number < self.min_species_size, self.min_species_size, spawn_number
)
# convert to int, this will also make nan to 0
spawn_number = spawn_number.astype(jnp.int32) spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size # must control the sum of spawn_number to be equal to pop_size