Perfect!
Next is to connect with Evox!
This commit is contained in:
@@ -1,7 +1,15 @@
|
||||
from typing import List, Tuple, Dict, Union, Callable
|
||||
"""
|
||||
Species Controller in NEAT.
|
||||
The code are modified from neat-python.
|
||||
See
|
||||
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Dict
|
||||
from itertools import count
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
@@ -37,14 +45,13 @@ class SpeciesController:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
self.species_elitism = self.config.neat.species.species_elitism
|
||||
self.pop_size = self.config.neat.population.pop_size
|
||||
self.max_stagnation = self.config.neat.species.max_stagnation
|
||||
self.min_species_size = self.config.neat.species.min_species_size
|
||||
self.genome_elitism = self.config.neat.species.genome_elitism
|
||||
self.survival_threshold = self.config.neat.species.survival_threshold
|
||||
self.species_elitism = self.config['species_elitism']
|
||||
self.pop_size = self.config['pop_size']
|
||||
self.max_stagnation = self.config['max_stagnation']
|
||||
self.min_species_size = self.config['min_species_size']
|
||||
self.genome_elitism = self.config['genome_elitism']
|
||||
self.survival_threshold = self.config['survival_threshold']
|
||||
|
||||
self.species_idxer = count(0)
|
||||
self.species: Dict[int, Species] = {} # species_id -> species
|
||||
|
||||
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
|
||||
@@ -55,9 +62,10 @@ class SpeciesController:
|
||||
:return:
|
||||
"""
|
||||
pop_size = pop_nodes.shape[0]
|
||||
species_id = next(self.species_idxer)
|
||||
species_id = 0 # the first species
|
||||
s = Species(species_id, 0)
|
||||
members = np.array(list(range(pop_size)))
|
||||
|
||||
s.update((pop_nodes[0], pop_connections[0]), members)
|
||||
self.species[species_id] = s
|
||||
|
||||
@@ -68,16 +76,14 @@ class SpeciesController:
|
||||
:return:
|
||||
"""
|
||||
for sid, s in self.species.items():
|
||||
# TODO: here use mean to measure the fitness of a species, but it may be other functions
|
||||
s.member_fitnesses = s.get_fitnesses(fitnesses)
|
||||
# s.fitness = np.mean(s.member_fitnesses)
|
||||
# use the max score to represent the fitness of the species
|
||||
s.fitness = np.max(s.member_fitnesses)
|
||||
s.fitness_history.append(s.fitness)
|
||||
s.adjusted_fitness = None
|
||||
|
||||
def __stagnation(self, generation):
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param generation:
|
||||
:return: whether the species is stagnated
|
||||
"""
|
||||
@@ -88,7 +94,7 @@ class SpeciesController:
|
||||
else:
|
||||
prev_fitness = float('-inf')
|
||||
|
||||
if prev_fitness is None or s.fitness > prev_fitness:
|
||||
if s.fitness > prev_fitness:
|
||||
s.last_improved = generation
|
||||
|
||||
species_data.append((sid, s))
|
||||
@@ -110,7 +116,6 @@ class SpeciesController:
|
||||
|
||||
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param fitnesses:
|
||||
:param generation:
|
||||
:return: crossover_pair for next generation.
|
||||
@@ -136,6 +141,8 @@ class SpeciesController:
|
||||
# No species left.
|
||||
assert remaining_species
|
||||
|
||||
|
||||
# TODO: Too complex!
|
||||
# Compute each species' member size in the next generation.
|
||||
|
||||
# Do not allow the fitness range to be zero, as we divide by it below.
|
||||
@@ -185,6 +192,7 @@ class SpeciesController:
|
||||
# only use good genomes to crossover
|
||||
sorted_members = sorted_members[:repro_cutoff]
|
||||
|
||||
# TODO: Genome with higher fitness should be more likely to be selected?
|
||||
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
|
||||
part1.extend(sorted_members[list_idx1])
|
||||
part2.extend(sorted_members[list_idx2])
|
||||
@@ -197,32 +205,37 @@ class SpeciesController:
|
||||
|
||||
return winner_part, loser_part, np.array(elite_mask)
|
||||
|
||||
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
|
||||
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
|
||||
for idx, key in enumerate(species_keys):
|
||||
if key == I_INT:
|
||||
continue
|
||||
|
||||
members = np.where(idx2specie == key)[0]
|
||||
assert len(members) > 0
|
||||
|
||||
if key not in self.species:
|
||||
# the new specie created in this generation
|
||||
s = Species(key, generation)
|
||||
self.species[key] = s
|
||||
|
||||
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
|
||||
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
|
||||
|
||||
def ask(self, fitnesses, generation, S, N, C):
|
||||
def ask(self, fitnesses, generation, symbols):
|
||||
self.__update_species_fitnesses(fitnesses)
|
||||
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
|
||||
pre_spe_center_cons = np.full((S, C, 4), np.nan)
|
||||
species_keys = np.full((S,), I_INT)
|
||||
|
||||
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
|
||||
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
|
||||
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
|
||||
species_keys = np.full((symbols['S'], ), I_INT)
|
||||
|
||||
for idx, (key, specie) in enumerate(self.species.items()):
|
||||
pre_spe_center_nodes[idx] = specie.representative[0]
|
||||
pre_spe_center_cons[idx] = specie.representative[1]
|
||||
center_nodes[idx], center_cons[idx] = specie.representative
|
||||
species_keys[idx] = key
|
||||
|
||||
next_new_specie_key = max(self.species.keys()) + 1
|
||||
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
|
||||
pre_spe_center_cons, species_keys, next_new_specie_key
|
||||
|
||||
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
|
||||
|
||||
|
||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
|
||||
Reference in New Issue
Block a user