Next is to connect with Evox!
This commit is contained in:
wls2002
2023-06-25 02:57:45 +08:00
parent 0cb2f9473d
commit ba369db0b2
14 changed files with 392 additions and 268 deletions

View File

@@ -1,7 +1,15 @@
from typing import List, Tuple, Dict, Union, Callable
"""
Species Controller in NEAT.
The code are modified from neat-python.
See
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
"""
from typing import List, Tuple, Dict
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
@@ -37,14 +45,13 @@ class SpeciesController:
def __init__(self, config):
self.config = config
self.species_elitism = self.config.neat.species.species_elitism
self.pop_size = self.config.neat.population.pop_size
self.max_stagnation = self.config.neat.species.max_stagnation
self.min_species_size = self.config.neat.species.min_species_size
self.genome_elitism = self.config.neat.species.genome_elitism
self.survival_threshold = self.config.neat.species.survival_threshold
self.species_elitism = self.config['species_elitism']
self.pop_size = self.config['pop_size']
self.max_stagnation = self.config['max_stagnation']
self.min_species_size = self.config['min_species_size']
self.genome_elitism = self.config['genome_elitism']
self.survival_threshold = self.config['survival_threshold']
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
@@ -55,9 +62,10 @@ class SpeciesController:
:return:
"""
pop_size = pop_nodes.shape[0]
species_id = next(self.species_idxer)
species_id = 0 # the first species
s = Species(species_id, 0)
members = np.array(list(range(pop_size)))
s.update((pop_nodes[0], pop_connections[0]), members)
self.species[species_id] = s
@@ -68,16 +76,14 @@ class SpeciesController:
:return:
"""
for sid, s in self.species.items():
# TODO: here use mean to measure the fitness of a species, but it may be other functions
s.member_fitnesses = s.get_fitnesses(fitnesses)
# s.fitness = np.mean(s.member_fitnesses)
# use the max score to represent the fitness of the species
s.fitness = np.max(s.member_fitnesses)
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
def __stagnation(self, generation):
"""
code modified from neat-python!
:param generation:
:return: whether the species is stagnated
"""
@@ -88,7 +94,7 @@ class SpeciesController:
else:
prev_fitness = float('-inf')
if prev_fitness is None or s.fitness > prev_fitness:
if s.fitness > prev_fitness:
s.last_improved = generation
species_data.append((sid, s))
@@ -110,7 +116,6 @@ class SpeciesController:
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
"""
code modified from neat-python!
:param fitnesses:
:param generation:
:return: crossover_pair for next generation.
@@ -136,6 +141,8 @@ class SpeciesController:
# No species left.
assert remaining_species
# TODO: Too complex!
# Compute each species' member size in the next generation.
# Do not allow the fitness range to be zero, as we divide by it below.
@@ -185,6 +192,7 @@ class SpeciesController:
# only use good genomes to crossover
sorted_members = sorted_members[:repro_cutoff]
# TODO: Genome with higher fitness should be more likely to be selected?
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
part1.extend(sorted_members[list_idx1])
part2.extend(sorted_members[list_idx2])
@@ -197,32 +205,37 @@ class SpeciesController:
return winner_part, loser_part, np.array(elite_mask)
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
for idx, key in enumerate(species_keys):
if key == I_INT:
continue
members = np.where(idx2specie == key)[0]
assert len(members) > 0
if key not in self.species:
# the new specie created in this generation
s = Species(key, generation)
self.species[key] = s
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
def ask(self, fitnesses, generation, S, N, C):
def ask(self, fitnesses, generation, symbols):
self.__update_species_fitnesses(fitnesses)
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
pre_spe_center_cons = np.full((S, C, 4), np.nan)
species_keys = np.full((S,), I_INT)
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
species_keys = np.full((symbols['S'], ), I_INT)
for idx, (key, specie) in enumerate(self.species.items()):
pre_spe_center_nodes[idx] = specie.representative[0]
pre_spe_center_cons[idx] = specie.representative[1]
center_nodes[idx], center_cons[idx] = specie.representative
species_keys[idx] = key
next_new_specie_key = max(self.species.keys()) + 1
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
pre_spe_center_cons, species_keys, next_new_specie_key
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):