This commit is contained in:
wls2002
2023-05-13 20:58:03 +08:00
parent 90a9cc322d
commit 72c9d4167a
10 changed files with 372 additions and 529 deletions

View File

@@ -5,6 +5,8 @@ import jax
import numpy as np
from numpy.typing import NDArray
from .genome.utils import I_INT
class Species(object):
@@ -12,7 +14,7 @@ class Species(object):
self.key = key
self.created = generation
self.last_improved = generation
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections)
self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections)
self.members: NDArray = None # idx in pop_nodes, pop_connections,
self.fitness = None
self.member_fitnesses = None
@@ -34,7 +36,7 @@ class SpeciesController:
def __init__(self, config):
self.config = config
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.species_elitism = self.config.neat.species.species_elitism
self.pop_size = self.config.neat.population.pop_size
self.max_stagnation = self.config.neat.species.max_stagnation
@@ -59,97 +61,7 @@ class SpeciesController:
s.update((pop_nodes[0], pop_connections[0]), members)
self.species[species_id] = s
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int,
o2o_distance: Callable, o2m_distance: Callable) -> None:
"""
:param pop_nodes:
:param pop_connections:
:param generation: use to flag the created time of new species
:param o2o_distance: distance function for one-to-one comparison
:param o2m_distance: distance function for one-to-many comparison
:return:
"""
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
previous_species_list = list(self.species.keys())
# Find the best representatives for each existing species.
new_representatives = {}
new_members = {}
total_distances = jax.device_get([
o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
for sid in previous_species_list
])
# TODO: Use jit to wrapper function find_min_with_mask to accelerate this part
for i, sid in enumerate(previous_species_list):
distances = total_distances[i]
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
new_members[sid] = [min_idx]
unspeciated[min_idx] = False
# Partition population into species based on genetic similarity.
# First, fast match the population to previous species
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = jax.device_get([
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
])
pop_res_distance = np.stack(res_pop_distance, axis=0).T
for i in range(pop_res_distance.shape[0]):
if not unspeciated[i]:
continue
min_idx = np.argmin(pop_res_distance[i])
min_val = pop_res_distance[i, min_idx]
if min_val <= self.compatibility_threshold:
species_id = previous_species_list[min_idx]
new_members[species_id].append(i)
unspeciated[i] = False
# Second, slowly match the lonely population to new-created species.s
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
# the new representatives.
for i in range(pop_nodes.shape[0]):
if not unspeciated[i]:
continue
unspeciated[i] = False
if len(new_representatives) != 0:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = jax.device_get([
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
])
distances = np.array(distances)
min_idx = np.argmin(distances)
min_val = distances[min_idx]
if min_val <= self.compatibility_threshold:
species_id = sid[min_idx]
new_members[species_id].append(i)
continue
# create a new species
species_id = next(self.species_idxer)
new_representatives[species_id] = i
new_members[species_id] = [i]
assert np.all(~unspeciated)
# Update species collection based on new speciation.
for sid, rid in new_representatives.items():
s = self.species.get(sid)
if s is None:
s = Species(sid, generation)
self.species[sid] = s
members = np.array(new_members[sid])
s.update((pop_nodes[rid], pop_connections[rid]), members)
def update_species_fitnesses(self, fitnesses):
def __update_species_fitnesses(self, fitnesses):
"""
update the fitness of each species
:param fitnesses:
@@ -163,7 +75,7 @@ class SpeciesController:
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
def stagnation(self, generation):
def __stagnation(self, generation):
"""
code modified from neat-python!
:param generation:
@@ -196,7 +108,7 @@ class SpeciesController:
result.append((sid, s, is_stagnant))
return result
def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
"""
code modified from neat-python!
:param fitnesses:
@@ -215,7 +127,7 @@ class SpeciesController:
max_fitness = -np.inf
remaining_species = []
for stag_sid, stag_s, stagnant in self.stagnation(generation):
for stag_sid, stag_s, stagnant in self.__stagnation(generation):
if not stagnant:
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
@@ -285,6 +197,33 @@ class SpeciesController:
return winner_part, loser_part, np.array(elite_mask)
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
for idx, key in enumerate(species_keys):
if key == I_INT:
continue
members = np.where(idx2specie == key)[0]
assert len(members) > 0
if key not in self.species:
s = Species(key, generation)
self.species[key] = s
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
def ask(self, fitnesses, generation, S, N, C):
self.__update_species_fitnesses(fitnesses)
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
pre_spe_center_cons = np.full((S, C, 4), np.nan)
species_keys = np.full((S,), I_INT)
for idx, (key, specie) in enumerate(self.species.items()):
pre_spe_center_nodes[idx] = specie.representative[0]
pre_spe_center_cons[idx] = specie.representative[1]
species_keys[idx] = key
next_new_specie_key = max(self.species.keys()) + 1
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
pre_spe_center_cons, species_keys, next_new_specie_key
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
"""
@@ -326,13 +265,7 @@ def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
return spawn_amounts
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
masked_arr = np.where(mask, arr, np.inf)
min_idx = np.argmin(masked_arr)
return min_idx
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
-> Tuple[NDArray, NDArray]:
sorted_idx = np.argsort(fitnesses)[::-1]
return members[sorted_idx], fitnesses[sorted_idx]
return members[sorted_idx], fitnesses[sorted_idx]