add species_number_calculate_by params in neat.py

according to (issue)[https://github.com/EMI-Group/tensorneat/issues/8]
This commit is contained in:
wls2002
2024-11-22 11:33:47 +08:00
parent b50128898c
commit 647e750b38
2 changed files with 56 additions and 6 deletions

View File

@@ -24,7 +24,14 @@ class NEAT(BaseAlgorithm):
min_species_size: int = 1,
compatibility_threshold: float = 2.0,
species_fitness_func: Callable = jnp.max,
species_number_calculate_by: str = "rank",
):
assert species_number_calculate_by in [
"rank",
"fitness",
], "species_number_calculate_by should be either 'rank' or 'fitness'"
self.genome = genome
self.pop_size = pop_size
self.species_controller = SpeciesController(
@@ -38,6 +45,7 @@ class NEAT(BaseAlgorithm):
min_species_size,
compatibility_threshold,
species_fitness_func,
species_number_calculate_by,
)
def setup(self, state=State()):

View File

@@ -26,6 +26,7 @@ class SpeciesController(StatefulBaseClass):
min_species_size,
compatibility_threshold,
species_fitness_func,
species_number_calculate_by,
):
self.pop_size = pop_size
self.species_size = species_size
@@ -38,6 +39,7 @@ class SpeciesController(StatefulBaseClass):
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.species_fitness_func = species_fitness_func
self.species_number_calculate_by = species_number_calculate_by
def setup(self, state, first_nodes, first_conns):
# the unique index (primary key) for each species
@@ -111,7 +113,12 @@ class SpeciesController(StatefulBaseClass):
)
# decide the number of members of each species by their fitness
spawn_number = self._cal_spawn_numbers(species_state)
if self.species_number_calculate_by == "rank":
spawn_number = self._cal_spawn_numbers_by_rank(species_state)
elif self.species_number_calculate_by == "fitness":
spawn_number = self._cal_spawn_numbers_by_fitness(species_state)
else:
raise ValueError("species_number_calculate_by must be 'rank' or 'fitness'")
k1, k2 = jax.random.split(state.randkey)
# crossover info
@@ -234,7 +241,7 @@ class SpeciesController(StatefulBaseClass):
species_fitness,
)
def _cal_spawn_numbers(self, species_state):
def _cal_spawn_numbers_by_rank(self, species_state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
@@ -250,11 +257,9 @@ class SpeciesController(StatefulBaseClass):
(valid_species_num + 1) * valid_species_num / 2
) # obtain 3 + 2 + 1 = 6
# calculate the spawn number rate by the rank of each species
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(
is_species_valid, spawn_number_rate, 0
) # set invalid species to 0
target_spawn_number = jnp.floor(
spawn_number_rate * self.pop_size
@@ -276,6 +281,41 @@ class SpeciesController(StatefulBaseClass):
return spawn_number
def _cal_spawn_numbers_by_fitness(self, species_state):
"""
decide the number of members of each species by their fitness.
the species with higher fitness will have more members
"""
species_keys = species_state.species_keys
# the fitness of each species
species_fitness = species_state.best_fitness
# calculate the spawn number rate by the fitness of each species
spawn_number_rate = species_fitness / jnp.sum(
species_fitness, where=~jnp.isnan(species_fitness)
)
target_spawn_number = jnp.floor(
spawn_number_rate * self.pop_size
) # calculate member
# Avoid too much variation of numbers for a species
previous_size = species_state.member_count
spawn_number = (
previous_size
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
)
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = self.pop_size - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
return spawn_number
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size)
@@ -503,7 +543,9 @@ class SpeciesController(StatefulBaseClass):
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(
state.species.best_fitness
)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
last_improved = jnp.where(
new_created_mask, state.generation, state.species.last_improved