initial commit

This commit is contained in:
wls2002
2023-05-05 14:19:13 +08:00
commit 6faa07f507
43 changed files with 2517 additions and 0 deletions

190
algorithms/neat/species.py Normal file
View File

@@ -0,0 +1,190 @@
from typing import List, Tuple, Dict
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
from .genome import distance
class Species(object):
def __init__(self, key, generation):
self.key = key
self.created = generation
self.last_improved = generation
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections)
self.members: List[int] = [] # idx in pop_nodes, pop_connections
self.fitness = None
self.member_fitnesses = None
self.adjusted_fitness = None
self.fitness_history: List[float] = []
def update(self, representative, members):
self.representative = representative
self.members = members
def get_fitnesses(self, fitnesses):
return [fitnesses[m] for m in self.members]
class SpeciesController:
"""
A class to control the species
"""
def __init__(self, config):
self.config = config
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.species_elitism = self.config.neat.species.species_elitism
self.max_stagnation = self.config.neat.species.max_stagnation
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.genome_to_species: Dict[int, int] = {}
self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many
# self.o2o_distance_func = np_distance # one to one
self.o2o_distance_func = distance
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
:param pop_nodes:
:param pop_connections:
:param generation: use to flag the created time of new species
:return:
"""
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
previous_species_list = list(self.species.keys())
# Find the best representatives for each existing species.
new_representatives = {}
new_members = {}
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances) # fetch the data from gpu
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
new_members[sid] = [min_idx]
unspeciated[min_idx] = False
# Partition population into species based on genetic similarity.
# First, fast match the population to previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(
[
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
]
)
]
pop_res_distance = np.stack(res_pop_distance, axis=0).T
for i in range(pop_res_distance.shape[0]):
if not unspeciated[i]:
continue
min_idx = np.argmin(pop_res_distance[i])
min_val = pop_res_distance[i, min_idx]
if min_val <= self.compatibility_threshold:
species_id = previous_species_list[min_idx]
new_members[species_id].append(i)
unspeciated[i] = False
# Second, slowly match the lonely population to new-created species.
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
# the new representatives.
new_species_list = []
for i in range(pop_nodes.shape[0]):
if not unspeciated[i]:
continue
unspeciated[i] = False
if len(new_representatives) != 0:
rid = [new_representatives[sid] for sid in new_representatives] # the representatives of new species
distances = [
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
min_idx = np.argmin(distances)
min_val = distances[min_idx]
if min_val <= self.compatibility_threshold:
species_id = new_species_list[min_idx]
new_members[species_id].append(i)
continue
# create a new species
species_id = next(self.species_idxer)
new_species_list.append(species_id)
new_representatives[species_id] = i
new_members[species_id] = [i]
assert np.all(~unspeciated)
# Update species collection based on new speciation.
self.genome_to_species = {}
for sid, rid in new_representatives.items():
s = self.species.get(sid)
if s is None:
s = Species(sid, generation)
self.species[sid] = s
members = new_members[sid]
for gid in members:
self.genome_to_species[gid] = sid
s.update((pop_nodes[rid], pop_connections[rid]), members)
def update_species_fitnesses(self, fitnesses):
"""
update the fitness of each species
:param fitnesses:
:return:
"""
for sid, s in self.species.items():
# TODO: here use mean to measure the fitness of a species, but it may be other functions
s.member_fitnesses = s.get_fitnesses(fitnesses)
s.fitness = np.mean(s.member_fitnesses)
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
def stagnation(self, generation):
"""
code modified from neat-python!
:param generation:
:return: whether the species is stagnated
"""
species_data = []
for sid, s in self.species.items():
if s.fitness_history:
prev_fitness = max(s.fitness_history)
else:
prev_fitness = float('-inf')
if prev_fitness is None or s.fitness > prev_fitness:
s.last_improved = generation
species_data.append((sid, s))
# Sort in descending fitness order.
species_data.sort(key=lambda x: x[1].fitness, reverse=True)
result = []
for idx, (sid, s) in enumerate(species_data):
if idx < self.species_elitism: # elitism species never stagnate!
is_stagnant = False
else:
stagnant_time = generation - s.last_improved
is_stagnant = stagnant_time > self.max_stagnation
result.append((sid, s, is_stagnant))
return result
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
masked_arr = np.where(mask, arr, np.inf)
min_idx = np.argmin(masked_arr)
return min_idx