From 5fdaf6806c2241a49975283255bcc43fc60a1bc9 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 22 Nov 2024 12:17:49 +0800 Subject: [PATCH] add normalize in _cal_spawn_numbers_by_fitness --- src/tensorneat/algorithm/neat/species.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/tensorneat/algorithm/neat/species.py b/src/tensorneat/algorithm/neat/species.py index 7e34ca6..e868d2b 100644 --- a/src/tensorneat/algorithm/neat/species.py +++ b/src/tensorneat/algorithm/neat/species.py @@ -271,6 +271,12 @@ class SpeciesController(StatefulBaseClass): previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate ) + + # maintain min_species_size, this will not influence nan + spawn_number = jnp.where( + spawn_number < self.min_species_size, self.min_species_size, spawn_number + ) + # convert to int, this will also make nan to 0 spawn_number = spawn_number.astype(jnp.int32) # must control the sum of spawn_number to be equal to pop_size @@ -287,11 +293,14 @@ class SpeciesController(StatefulBaseClass): the species with higher fitness will have more members """ - species_keys = species_state.species_keys - # the fitness of each species species_fitness = species_state.best_fitness + # normalize the fitness before calculating the spawn number + # consider that the fitness may be negative + # in this way the species with the lowest fitness will have spawn_number = 0 + species_fitness = species_fitness - jnp.min(species_fitness) + # calculate the spawn number rate by the fitness of each species spawn_number_rate = species_fitness / jnp.sum( species_fitness, where=~jnp.isnan(species_fitness) @@ -306,6 +315,12 @@ class SpeciesController(StatefulBaseClass): previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate ) + # maintain min_species_size, this will not influence nan + spawn_number = jnp.where( + spawn_number < self.min_species_size, self.min_species_size, spawn_number + ) + + # convert to int, this will also make nan to 0 spawn_number = spawn_number.astype(jnp.int32) # must control the sum of spawn_number to be equal to pop_size