add species_number_calculate_by params in neat.py
according to (issue)[https://github.com/EMI-Group/tensorneat/issues/8]
This commit is contained in:
@@ -24,7 +24,14 @@ class NEAT(BaseAlgorithm):
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 2.0,
|
||||
species_fitness_func: Callable = jnp.max,
|
||||
species_number_calculate_by: str = "rank",
|
||||
):
|
||||
|
||||
assert species_number_calculate_by in [
|
||||
"rank",
|
||||
"fitness",
|
||||
], "species_number_calculate_by should be either 'rank' or 'fitness'"
|
||||
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_controller = SpeciesController(
|
||||
@@ -38,6 +45,7 @@ class NEAT(BaseAlgorithm):
|
||||
min_species_size,
|
||||
compatibility_threshold,
|
||||
species_fitness_func,
|
||||
species_number_calculate_by,
|
||||
)
|
||||
|
||||
def setup(self, state=State()):
|
||||
|
||||
@@ -26,6 +26,7 @@ class SpeciesController(StatefulBaseClass):
|
||||
min_species_size,
|
||||
compatibility_threshold,
|
||||
species_fitness_func,
|
||||
species_number_calculate_by,
|
||||
):
|
||||
self.pop_size = pop_size
|
||||
self.species_size = species_size
|
||||
@@ -38,6 +39,7 @@ class SpeciesController(StatefulBaseClass):
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.species_fitness_func = species_fitness_func
|
||||
self.species_number_calculate_by = species_number_calculate_by
|
||||
|
||||
def setup(self, state, first_nodes, first_conns):
|
||||
# the unique index (primary key) for each species
|
||||
@@ -111,7 +113,12 @@ class SpeciesController(StatefulBaseClass):
|
||||
)
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
spawn_number = self._cal_spawn_numbers(species_state)
|
||||
if self.species_number_calculate_by == "rank":
|
||||
spawn_number = self._cal_spawn_numbers_by_rank(species_state)
|
||||
elif self.species_number_calculate_by == "fitness":
|
||||
spawn_number = self._cal_spawn_numbers_by_fitness(species_state)
|
||||
else:
|
||||
raise ValueError("species_number_calculate_by must be 'rank' or 'fitness'")
|
||||
|
||||
k1, k2 = jax.random.split(state.randkey)
|
||||
# crossover info
|
||||
@@ -234,7 +241,7 @@ class SpeciesController(StatefulBaseClass):
|
||||
species_fitness,
|
||||
)
|
||||
|
||||
def _cal_spawn_numbers(self, species_state):
|
||||
def _cal_spawn_numbers_by_rank(self, species_state):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
the species with higher fitness will have more members
|
||||
@@ -250,11 +257,9 @@ class SpeciesController(StatefulBaseClass):
|
||||
(valid_species_num + 1) * valid_species_num / 2
|
||||
) # obtain 3 + 2 + 1 = 6
|
||||
|
||||
# calculate the spawn number rate by the rank of each species
|
||||
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(
|
||||
is_species_valid, spawn_number_rate, 0
|
||||
) # set invalid species to 0
|
||||
|
||||
target_spawn_number = jnp.floor(
|
||||
spawn_number_rate * self.pop_size
|
||||
@@ -276,6 +281,41 @@ class SpeciesController(StatefulBaseClass):
|
||||
|
||||
return spawn_number
|
||||
|
||||
def _cal_spawn_numbers_by_fitness(self, species_state):
|
||||
"""
|
||||
decide the number of members of each species by their fitness.
|
||||
the species with higher fitness will have more members
|
||||
"""
|
||||
|
||||
species_keys = species_state.species_keys
|
||||
|
||||
# the fitness of each species
|
||||
species_fitness = species_state.best_fitness
|
||||
|
||||
# calculate the spawn number rate by the fitness of each species
|
||||
spawn_number_rate = species_fitness / jnp.sum(
|
||||
species_fitness, where=~jnp.isnan(species_fitness)
|
||||
)
|
||||
target_spawn_number = jnp.floor(
|
||||
spawn_number_rate * self.pop_size
|
||||
) # calculate member
|
||||
|
||||
# Avoid too much variation of numbers for a species
|
||||
previous_size = species_state.member_count
|
||||
spawn_number = (
|
||||
previous_size
|
||||
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
|
||||
)
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = self.pop_size - jnp.sum(spawn_number)
|
||||
|
||||
# add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(error)
|
||||
|
||||
return spawn_number
|
||||
|
||||
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
|
||||
s_idx = self.species_arange
|
||||
p_idx = jnp.arange(self.pop_size)
|
||||
@@ -503,7 +543,9 @@ class SpeciesController(StatefulBaseClass):
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(
|
||||
state.species.best_fitness
|
||||
)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
|
||||
last_improved = jnp.where(
|
||||
new_created_mask, state.generation, state.species.last_improved
|
||||
|
||||
Reference in New Issue
Block a user