add species_number_calculate_by params in neat.py

according to (issue)[https://github.com/EMI-Group/tensorneat/issues/8]
This commit is contained in:
wls2002
2024-11-22 11:33:47 +08:00
parent b50128898c
commit 647e750b38
2 changed files with 56 additions and 6 deletions

View File

@@ -24,7 +24,14 @@ class NEAT(BaseAlgorithm):
min_species_size: int = 1, min_species_size: int = 1,
compatibility_threshold: float = 2.0, compatibility_threshold: float = 2.0,
species_fitness_func: Callable = jnp.max, species_fitness_func: Callable = jnp.max,
species_number_calculate_by: str = "rank",
): ):
assert species_number_calculate_by in [
"rank",
"fitness",
], "species_number_calculate_by should be either 'rank' or 'fitness'"
self.genome = genome self.genome = genome
self.pop_size = pop_size self.pop_size = pop_size
self.species_controller = SpeciesController( self.species_controller = SpeciesController(
@@ -38,6 +45,7 @@ class NEAT(BaseAlgorithm):
min_species_size, min_species_size,
compatibility_threshold, compatibility_threshold,
species_fitness_func, species_fitness_func,
species_number_calculate_by,
) )
def setup(self, state=State()): def setup(self, state=State()):

View File

@@ -26,6 +26,7 @@ class SpeciesController(StatefulBaseClass):
min_species_size, min_species_size,
compatibility_threshold, compatibility_threshold,
species_fitness_func, species_fitness_func,
species_number_calculate_by,
): ):
self.pop_size = pop_size self.pop_size = pop_size
self.species_size = species_size self.species_size = species_size
@@ -38,6 +39,7 @@ class SpeciesController(StatefulBaseClass):
self.min_species_size = min_species_size self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold self.compatibility_threshold = compatibility_threshold
self.species_fitness_func = species_fitness_func self.species_fitness_func = species_fitness_func
self.species_number_calculate_by = species_number_calculate_by
def setup(self, state, first_nodes, first_conns): def setup(self, state, first_nodes, first_conns):
# the unique index (primary key) for each species # the unique index (primary key) for each species
@@ -111,7 +113,12 @@ class SpeciesController(StatefulBaseClass):
) )
# decide the number of members of each species by their fitness # decide the number of members of each species by their fitness
spawn_number = self._cal_spawn_numbers(species_state) if self.species_number_calculate_by == "rank":
spawn_number = self._cal_spawn_numbers_by_rank(species_state)
elif self.species_number_calculate_by == "fitness":
spawn_number = self._cal_spawn_numbers_by_fitness(species_state)
else:
raise ValueError("species_number_calculate_by must be 'rank' or 'fitness'")
k1, k2 = jax.random.split(state.randkey) k1, k2 = jax.random.split(state.randkey)
# crossover info # crossover info
@@ -234,7 +241,7 @@ class SpeciesController(StatefulBaseClass):
species_fitness, species_fitness,
) )
def _cal_spawn_numbers(self, species_state): def _cal_spawn_numbers_by_rank(self, species_state):
""" """
decide the number of members of each species by their fitness rank. decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members the species with higher fitness will have more members
@@ -250,11 +257,9 @@ class SpeciesController(StatefulBaseClass):
(valid_species_num + 1) * valid_species_num / 2 (valid_species_num + 1) * valid_species_num / 2
) # obtain 3 + 2 + 1 = 6 ) # obtain 3 + 2 + 1 = 6
# calculate the spawn number rate by the rank of each species
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1] rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(
is_species_valid, spawn_number_rate, 0
) # set invalid species to 0
target_spawn_number = jnp.floor( target_spawn_number = jnp.floor(
spawn_number_rate * self.pop_size spawn_number_rate * self.pop_size
@@ -276,6 +281,41 @@ class SpeciesController(StatefulBaseClass):
return spawn_number return spawn_number
def _cal_spawn_numbers_by_fitness(self, species_state):
"""
decide the number of members of each species by their fitness.
the species with higher fitness will have more members
"""
species_keys = species_state.species_keys
# the fitness of each species
species_fitness = species_state.best_fitness
# calculate the spawn number rate by the fitness of each species
spawn_number_rate = species_fitness / jnp.sum(
species_fitness, where=~jnp.isnan(species_fitness)
)
target_spawn_number = jnp.floor(
spawn_number_rate * self.pop_size
) # calculate member
# Avoid too much variation of numbers for a species
previous_size = species_state.member_count
spawn_number = (
previous_size
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
)
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = self.pop_size - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
return spawn_number
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness): def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
s_idx = self.species_arange s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size) p_idx = jnp.arange(self.pop_size)
@@ -503,7 +543,9 @@ class SpeciesController(StatefulBaseClass):
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation # complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness) new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(
state.species.best_fitness
)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness) best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
last_improved = jnp.where( last_improved = jnp.where(
new_created_mask, state.generation, state.species.last_improved new_created_mask, state.generation, state.species.last_improved