add species_number_calculate_by params in neat.py
according to (issue)[https://github.com/EMI-Group/tensorneat/issues/8]
This commit is contained in:
@@ -24,7 +24,14 @@ class NEAT(BaseAlgorithm):
|
|||||||
min_species_size: int = 1,
|
min_species_size: int = 1,
|
||||||
compatibility_threshold: float = 2.0,
|
compatibility_threshold: float = 2.0,
|
||||||
species_fitness_func: Callable = jnp.max,
|
species_fitness_func: Callable = jnp.max,
|
||||||
|
species_number_calculate_by: str = "rank",
|
||||||
):
|
):
|
||||||
|
|
||||||
|
assert species_number_calculate_by in [
|
||||||
|
"rank",
|
||||||
|
"fitness",
|
||||||
|
], "species_number_calculate_by should be either 'rank' or 'fitness'"
|
||||||
|
|
||||||
self.genome = genome
|
self.genome = genome
|
||||||
self.pop_size = pop_size
|
self.pop_size = pop_size
|
||||||
self.species_controller = SpeciesController(
|
self.species_controller = SpeciesController(
|
||||||
@@ -38,6 +45,7 @@ class NEAT(BaseAlgorithm):
|
|||||||
min_species_size,
|
min_species_size,
|
||||||
compatibility_threshold,
|
compatibility_threshold,
|
||||||
species_fitness_func,
|
species_fitness_func,
|
||||||
|
species_number_calculate_by,
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class SpeciesController(StatefulBaseClass):
|
|||||||
min_species_size,
|
min_species_size,
|
||||||
compatibility_threshold,
|
compatibility_threshold,
|
||||||
species_fitness_func,
|
species_fitness_func,
|
||||||
|
species_number_calculate_by,
|
||||||
):
|
):
|
||||||
self.pop_size = pop_size
|
self.pop_size = pop_size
|
||||||
self.species_size = species_size
|
self.species_size = species_size
|
||||||
@@ -38,6 +39,7 @@ class SpeciesController(StatefulBaseClass):
|
|||||||
self.min_species_size = min_species_size
|
self.min_species_size = min_species_size
|
||||||
self.compatibility_threshold = compatibility_threshold
|
self.compatibility_threshold = compatibility_threshold
|
||||||
self.species_fitness_func = species_fitness_func
|
self.species_fitness_func = species_fitness_func
|
||||||
|
self.species_number_calculate_by = species_number_calculate_by
|
||||||
|
|
||||||
def setup(self, state, first_nodes, first_conns):
|
def setup(self, state, first_nodes, first_conns):
|
||||||
# the unique index (primary key) for each species
|
# the unique index (primary key) for each species
|
||||||
@@ -111,7 +113,12 @@ class SpeciesController(StatefulBaseClass):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# decide the number of members of each species by their fitness
|
# decide the number of members of each species by their fitness
|
||||||
spawn_number = self._cal_spawn_numbers(species_state)
|
if self.species_number_calculate_by == "rank":
|
||||||
|
spawn_number = self._cal_spawn_numbers_by_rank(species_state)
|
||||||
|
elif self.species_number_calculate_by == "fitness":
|
||||||
|
spawn_number = self._cal_spawn_numbers_by_fitness(species_state)
|
||||||
|
else:
|
||||||
|
raise ValueError("species_number_calculate_by must be 'rank' or 'fitness'")
|
||||||
|
|
||||||
k1, k2 = jax.random.split(state.randkey)
|
k1, k2 = jax.random.split(state.randkey)
|
||||||
# crossover info
|
# crossover info
|
||||||
@@ -234,7 +241,7 @@ class SpeciesController(StatefulBaseClass):
|
|||||||
species_fitness,
|
species_fitness,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _cal_spawn_numbers(self, species_state):
|
def _cal_spawn_numbers_by_rank(self, species_state):
|
||||||
"""
|
"""
|
||||||
decide the number of members of each species by their fitness rank.
|
decide the number of members of each species by their fitness rank.
|
||||||
the species with higher fitness will have more members
|
the species with higher fitness will have more members
|
||||||
@@ -250,11 +257,9 @@ class SpeciesController(StatefulBaseClass):
|
|||||||
(valid_species_num + 1) * valid_species_num / 2
|
(valid_species_num + 1) * valid_species_num / 2
|
||||||
) # obtain 3 + 2 + 1 = 6
|
) # obtain 3 + 2 + 1 = 6
|
||||||
|
|
||||||
|
# calculate the spawn number rate by the rank of each species
|
||||||
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
|
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
|
||||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||||
spawn_number_rate = jnp.where(
|
|
||||||
is_species_valid, spawn_number_rate, 0
|
|
||||||
) # set invalid species to 0
|
|
||||||
|
|
||||||
target_spawn_number = jnp.floor(
|
target_spawn_number = jnp.floor(
|
||||||
spawn_number_rate * self.pop_size
|
spawn_number_rate * self.pop_size
|
||||||
@@ -276,6 +281,41 @@ class SpeciesController(StatefulBaseClass):
|
|||||||
|
|
||||||
return spawn_number
|
return spawn_number
|
||||||
|
|
||||||
|
def _cal_spawn_numbers_by_fitness(self, species_state):
|
||||||
|
"""
|
||||||
|
decide the number of members of each species by their fitness.
|
||||||
|
the species with higher fitness will have more members
|
||||||
|
"""
|
||||||
|
|
||||||
|
species_keys = species_state.species_keys
|
||||||
|
|
||||||
|
# the fitness of each species
|
||||||
|
species_fitness = species_state.best_fitness
|
||||||
|
|
||||||
|
# calculate the spawn number rate by the fitness of each species
|
||||||
|
spawn_number_rate = species_fitness / jnp.sum(
|
||||||
|
species_fitness, where=~jnp.isnan(species_fitness)
|
||||||
|
)
|
||||||
|
target_spawn_number = jnp.floor(
|
||||||
|
spawn_number_rate * self.pop_size
|
||||||
|
) # calculate member
|
||||||
|
|
||||||
|
# Avoid too much variation of numbers for a species
|
||||||
|
previous_size = species_state.member_count
|
||||||
|
spawn_number = (
|
||||||
|
previous_size
|
||||||
|
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
|
||||||
|
)
|
||||||
|
spawn_number = spawn_number.astype(jnp.int32)
|
||||||
|
|
||||||
|
# must control the sum of spawn_number to be equal to pop_size
|
||||||
|
error = self.pop_size - jnp.sum(spawn_number)
|
||||||
|
|
||||||
|
# add error to the first species to control the sum of spawn_number
|
||||||
|
spawn_number = spawn_number.at[0].add(error)
|
||||||
|
|
||||||
|
return spawn_number
|
||||||
|
|
||||||
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
|
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
|
||||||
s_idx = self.species_arange
|
s_idx = self.species_arange
|
||||||
p_idx = jnp.arange(self.pop_size)
|
p_idx = jnp.arange(self.pop_size)
|
||||||
@@ -503,7 +543,9 @@ class SpeciesController(StatefulBaseClass):
|
|||||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||||
|
|
||||||
# complete info of species which is created in this generation
|
# complete info of species which is created in this generation
|
||||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
|
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(
|
||||||
|
state.species.best_fitness
|
||||||
|
)
|
||||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
|
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
|
||||||
last_improved = jnp.where(
|
last_improved = jnp.where(
|
||||||
new_created_mask, state.generation, state.species.last_improved
|
new_created_mask, state.generation, state.species.last_improved
|
||||||
|
|||||||
Reference in New Issue
Block a user