From 647e750b38ab0a5cbd6aad8c8c0f03ae5ee494c7 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 22 Nov 2024 11:33:47 +0800 Subject: [PATCH] add species_number_calculate_by params in neat.py according to (issue)[https://github.com/EMI-Group/tensorneat/issues/8] --- src/tensorneat/algorithm/neat/neat.py | 8 ++++ src/tensorneat/algorithm/neat/species.py | 54 +++++++++++++++++++++--- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/src/tensorneat/algorithm/neat/neat.py b/src/tensorneat/algorithm/neat/neat.py index 09652fa..0487e29 100644 --- a/src/tensorneat/algorithm/neat/neat.py +++ b/src/tensorneat/algorithm/neat/neat.py @@ -24,7 +24,14 @@ class NEAT(BaseAlgorithm): min_species_size: int = 1, compatibility_threshold: float = 2.0, species_fitness_func: Callable = jnp.max, + species_number_calculate_by: str = "rank", ): + + assert species_number_calculate_by in [ + "rank", + "fitness", + ], "species_number_calculate_by should be either 'rank' or 'fitness'" + self.genome = genome self.pop_size = pop_size self.species_controller = SpeciesController( @@ -38,6 +45,7 @@ class NEAT(BaseAlgorithm): min_species_size, compatibility_threshold, species_fitness_func, + species_number_calculate_by, ) def setup(self, state=State()): diff --git a/src/tensorneat/algorithm/neat/species.py b/src/tensorneat/algorithm/neat/species.py index 0537b70..7e34ca6 100644 --- a/src/tensorneat/algorithm/neat/species.py +++ b/src/tensorneat/algorithm/neat/species.py @@ -26,6 +26,7 @@ class SpeciesController(StatefulBaseClass): min_species_size, compatibility_threshold, species_fitness_func, + species_number_calculate_by, ): self.pop_size = pop_size self.species_size = species_size @@ -38,6 +39,7 @@ class SpeciesController(StatefulBaseClass): self.min_species_size = min_species_size self.compatibility_threshold = compatibility_threshold self.species_fitness_func = species_fitness_func + self.species_number_calculate_by = species_number_calculate_by def setup(self, state, first_nodes, first_conns): # the unique index (primary key) for each species @@ -111,7 +113,12 @@ class SpeciesController(StatefulBaseClass): ) # decide the number of members of each species by their fitness - spawn_number = self._cal_spawn_numbers(species_state) + if self.species_number_calculate_by == "rank": + spawn_number = self._cal_spawn_numbers_by_rank(species_state) + elif self.species_number_calculate_by == "fitness": + spawn_number = self._cal_spawn_numbers_by_fitness(species_state) + else: + raise ValueError("species_number_calculate_by must be 'rank' or 'fitness'") k1, k2 = jax.random.split(state.randkey) # crossover info @@ -234,7 +241,7 @@ class SpeciesController(StatefulBaseClass): species_fitness, ) - def _cal_spawn_numbers(self, species_state): + def _cal_spawn_numbers_by_rank(self, species_state): """ decide the number of members of each species by their fitness rank. the species with higher fitness will have more members @@ -250,11 +257,9 @@ class SpeciesController(StatefulBaseClass): (valid_species_num + 1) * valid_species_num / 2 ) # obtain 3 + 2 + 1 = 6 + # calculate the spawn number rate by the rank of each species rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1] spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] - spawn_number_rate = jnp.where( - is_species_valid, spawn_number_rate, 0 - ) # set invalid species to 0 target_spawn_number = jnp.floor( spawn_number_rate * self.pop_size @@ -276,6 +281,41 @@ class SpeciesController(StatefulBaseClass): return spawn_number + def _cal_spawn_numbers_by_fitness(self, species_state): + """ + decide the number of members of each species by their fitness. + the species with higher fitness will have more members + """ + + species_keys = species_state.species_keys + + # the fitness of each species + species_fitness = species_state.best_fitness + + # calculate the spawn number rate by the fitness of each species + spawn_number_rate = species_fitness / jnp.sum( + species_fitness, where=~jnp.isnan(species_fitness) + ) + target_spawn_number = jnp.floor( + spawn_number_rate * self.pop_size + ) # calculate member + + # Avoid too much variation of numbers for a species + previous_size = species_state.member_count + spawn_number = ( + previous_size + + (target_spawn_number - previous_size) * self.spawn_number_change_rate + ) + spawn_number = spawn_number.astype(jnp.int32) + + # must control the sum of spawn_number to be equal to pop_size + error = self.pop_size - jnp.sum(spawn_number) + + # add error to the first species to control the sum of spawn_number + spawn_number = spawn_number.at[0].add(error) + + return spawn_number + def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness): s_idx = self.species_arange p_idx = jnp.arange(self.pop_size) @@ -503,7 +543,9 @@ class SpeciesController(StatefulBaseClass): idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) # complete info of species which is created in this generation - new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness) + new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan( + state.species.best_fitness + ) best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness) last_improved = jnp.where( new_created_mask, state.generation, state.species.last_improved