FAST!
This commit is contained in:
@@ -5,6 +5,8 @@ import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .genome.utils import I_INT
|
||||
|
||||
|
||||
class Species(object):
|
||||
|
||||
@@ -12,7 +14,7 @@ class Species(object):
|
||||
self.key = key
|
||||
self.created = generation
|
||||
self.last_improved = generation
|
||||
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections)
|
||||
self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections)
|
||||
self.members: NDArray = None # idx in pop_nodes, pop_connections,
|
||||
self.fitness = None
|
||||
self.member_fitnesses = None
|
||||
@@ -34,7 +36,7 @@ class SpeciesController:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
|
||||
|
||||
self.species_elitism = self.config.neat.species.species_elitism
|
||||
self.pop_size = self.config.neat.population.pop_size
|
||||
self.max_stagnation = self.config.neat.species.max_stagnation
|
||||
@@ -59,97 +61,7 @@ class SpeciesController:
|
||||
s.update((pop_nodes[0], pop_connections[0]), members)
|
||||
self.species[species_id] = s
|
||||
|
||||
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int,
|
||||
o2o_distance: Callable, o2m_distance: Callable) -> None:
|
||||
"""
|
||||
:param pop_nodes:
|
||||
:param pop_connections:
|
||||
:param generation: use to flag the created time of new species
|
||||
:param o2o_distance: distance function for one-to-one comparison
|
||||
:param o2m_distance: distance function for one-to-many comparison
|
||||
:return:
|
||||
"""
|
||||
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
|
||||
previous_species_list = list(self.species.keys())
|
||||
|
||||
# Find the best representatives for each existing species.
|
||||
new_representatives = {}
|
||||
new_members = {}
|
||||
|
||||
total_distances = jax.device_get([
|
||||
o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
|
||||
for sid in previous_species_list
|
||||
])
|
||||
|
||||
# TODO: Use jit to wrapper function find_min_with_mask to accelerate this part
|
||||
for i, sid in enumerate(previous_species_list):
|
||||
distances = total_distances[i]
|
||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||
|
||||
new_representatives[sid] = min_idx
|
||||
new_members[sid] = [min_idx]
|
||||
unspeciated[min_idx] = False
|
||||
|
||||
# Partition population into species based on genetic similarity.
|
||||
|
||||
# First, fast match the population to previous species
|
||||
if previous_species_list: # exist previous species
|
||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||
res_pop_distance = jax.device_get([
|
||||
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
for rid in rid_list
|
||||
])
|
||||
|
||||
pop_res_distance = np.stack(res_pop_distance, axis=0).T
|
||||
for i in range(pop_res_distance.shape[0]):
|
||||
if not unspeciated[i]:
|
||||
continue
|
||||
min_idx = np.argmin(pop_res_distance[i])
|
||||
min_val = pop_res_distance[i, min_idx]
|
||||
if min_val <= self.compatibility_threshold:
|
||||
species_id = previous_species_list[min_idx]
|
||||
new_members[species_id].append(i)
|
||||
unspeciated[i] = False
|
||||
|
||||
# Second, slowly match the lonely population to new-created species.s
|
||||
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
|
||||
# the new representatives.
|
||||
for i in range(pop_nodes.shape[0]):
|
||||
if not unspeciated[i]:
|
||||
continue
|
||||
unspeciated[i] = False
|
||||
if len(new_representatives) != 0:
|
||||
# the representatives of new species
|
||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||
distances = jax.device_get([
|
||||
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
for r in rid
|
||||
])
|
||||
distances = np.array(distances)
|
||||
min_idx = np.argmin(distances)
|
||||
min_val = distances[min_idx]
|
||||
if min_val <= self.compatibility_threshold:
|
||||
species_id = sid[min_idx]
|
||||
new_members[species_id].append(i)
|
||||
continue
|
||||
# create a new species
|
||||
species_id = next(self.species_idxer)
|
||||
new_representatives[species_id] = i
|
||||
new_members[species_id] = [i]
|
||||
|
||||
assert np.all(~unspeciated)
|
||||
|
||||
# Update species collection based on new speciation.
|
||||
for sid, rid in new_representatives.items():
|
||||
s = self.species.get(sid)
|
||||
if s is None:
|
||||
s = Species(sid, generation)
|
||||
self.species[sid] = s
|
||||
|
||||
members = np.array(new_members[sid])
|
||||
s.update((pop_nodes[rid], pop_connections[rid]), members)
|
||||
|
||||
def update_species_fitnesses(self, fitnesses):
|
||||
def __update_species_fitnesses(self, fitnesses):
|
||||
"""
|
||||
update the fitness of each species
|
||||
:param fitnesses:
|
||||
@@ -163,7 +75,7 @@ class SpeciesController:
|
||||
s.fitness_history.append(s.fitness)
|
||||
s.adjusted_fitness = None
|
||||
|
||||
def stagnation(self, generation):
|
||||
def __stagnation(self, generation):
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param generation:
|
||||
@@ -196,7 +108,7 @@ class SpeciesController:
|
||||
result.append((sid, s, is_stagnant))
|
||||
return result
|
||||
|
||||
def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param fitnesses:
|
||||
@@ -215,7 +127,7 @@ class SpeciesController:
|
||||
max_fitness = -np.inf
|
||||
|
||||
remaining_species = []
|
||||
for stag_sid, stag_s, stagnant in self.stagnation(generation):
|
||||
for stag_sid, stag_s, stagnant in self.__stagnation(generation):
|
||||
if not stagnant:
|
||||
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
|
||||
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
|
||||
@@ -285,6 +197,33 @@ class SpeciesController:
|
||||
|
||||
return winner_part, loser_part, np.array(elite_mask)
|
||||
|
||||
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
|
||||
for idx, key in enumerate(species_keys):
|
||||
if key == I_INT:
|
||||
continue
|
||||
|
||||
members = np.where(idx2specie == key)[0]
|
||||
assert len(members) > 0
|
||||
if key not in self.species:
|
||||
s = Species(key, generation)
|
||||
self.species[key] = s
|
||||
|
||||
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
|
||||
|
||||
def ask(self, fitnesses, generation, S, N, C):
|
||||
self.__update_species_fitnesses(fitnesses)
|
||||
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
|
||||
pre_spe_center_cons = np.full((S, C, 4), np.nan)
|
||||
species_keys = np.full((S,), I_INT)
|
||||
for idx, (key, specie) in enumerate(self.species.items()):
|
||||
pre_spe_center_nodes[idx] = specie.representative[0]
|
||||
pre_spe_center_cons[idx] = specie.representative[1]
|
||||
species_keys[idx] = key
|
||||
next_new_specie_key = max(self.species.keys()) + 1
|
||||
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
|
||||
pre_spe_center_cons, species_keys, next_new_specie_key
|
||||
|
||||
|
||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
"""
|
||||
@@ -326,13 +265,7 @@ def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
return spawn_amounts
|
||||
|
||||
|
||||
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
|
||||
masked_arr = np.where(mask, arr, np.inf)
|
||||
min_idx = np.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
|
||||
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
|
||||
-> Tuple[NDArray, NDArray]:
|
||||
sorted_idx = np.argsort(fitnesses)[::-1]
|
||||
return members[sorted_idx], fitnesses[sorted_idx]
|
||||
return members[sorted_idx], fitnesses[sorted_idx]
|
||||
|
||||
Reference in New Issue
Block a user