The whole NEAT algorithm is written into functional programming.
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
"""
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single
|
||||
from .operations import create_next_generation_then_speciate
|
||||
from .species import SpeciesController
|
||||
from .genome import create_forward_function, topological_sort, unflatten_connections
|
||||
from .population import update_species, create_next_generation, speciate
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .crossover import crossover
|
||||
from .forward import create_forward
|
||||
from .graph import topological_sort, check_cycles
|
||||
from .utils import unflatten_connections
|
||||
from .genome import initialize_genomes, expand, expand_single
|
||||
from .forward import create_forward_function
|
||||
|
||||
@@ -5,7 +5,7 @@ from jax import jit, vmap
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
def create_forward(config):
|
||||
def create_forward_function(config):
|
||||
"""
|
||||
meta method to create forward function
|
||||
"""
|
||||
@@ -83,4 +83,22 @@ def create_forward(config):
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
|
||||
batch_forward = vmap(forward, in_axes=(0, None, None, None))
|
||||
|
||||
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
|
||||
|
||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
if config['forward_way'] == 'single':
|
||||
return jit(batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'pop':
|
||||
return jit(pop_batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'common':
|
||||
return jit(common_forward)
|
||||
|
||||
return forward
|
||||
|
||||
@@ -65,55 +65,6 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
return pop_nodes, pop_cons
|
||||
|
||||
|
||||
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Expand a single genome to accommodate more nodes or connections.
|
||||
:param nodes: (N, 5)
|
||||
:param cons: (C, 4)
|
||||
:param new_N:
|
||||
:param new_C:
|
||||
:return: (new_N, 5), (new_C, 4)
|
||||
"""
|
||||
old_N, old_C = nodes.shape[0], cons.shape[0]
|
||||
new_nodes = np.full((new_N, 5), np.nan)
|
||||
new_nodes[:old_N, :] = nodes
|
||||
|
||||
new_cons = np.full((new_C, 4), np.nan)
|
||||
new_cons[:old_C, :] = cons
|
||||
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Expand the population to accommodate more nodes or connections.
|
||||
:param pop_nodes: (pop_size, N, 5)
|
||||
:param pop_cons: (pop_size, C, 4)
|
||||
:param new_N:
|
||||
:param new_C:
|
||||
:return: (pop_size, new_N, 5), (pop_size, new_C, 4)
|
||||
"""
|
||||
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
|
||||
|
||||
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
|
||||
new_pop_nodes[:, :old_N, :] = pop_nodes
|
||||
|
||||
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
|
||||
new_pop_cons[:, :old_C, :] = pop_cons
|
||||
|
||||
return new_pop_nodes, new_pop_cons
|
||||
|
||||
|
||||
@jit
|
||||
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Count how many nodes and connections are in the genome.
|
||||
"""
|
||||
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
||||
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
||||
return node_cnt, cons_cnt
|
||||
|
||||
|
||||
@jit
|
||||
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
|
||||
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
|
||||
|
||||
@@ -59,12 +59,13 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from large to small.
|
||||
if reverse is True, the rank is from small to large. default large to small
|
||||
"""
|
||||
if reverse:
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
@@ -1,160 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import jit, numpy as jnp, vmap
|
||||
|
||||
from .genome.utils import rank_elements
|
||||
|
||||
|
||||
@jit
|
||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
args:
|
||||
randkey: random key
|
||||
fitness: Array[(pop_size,), float], the fitness of each individual
|
||||
species_keys: Array[(species_size, 3), float], the information of each species
|
||||
[species_key, best_score, last_update]
|
||||
idx2species: Array[(pop_size,), int], map the individual to its species
|
||||
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
|
||||
center_cons: Array[(species_size, C, 4), float], the center connections of each species
|
||||
generation: int, current generation
|
||||
jit_config: Dict, the configuration of jit functions
|
||||
"""
|
||||
|
||||
# update the fitness of each species
|
||||
species_fitness = update_species_fitness(species_info, idx2species, fitness)
|
||||
|
||||
# stagnation species
|
||||
species_fitness, species_info, center_nodes, center_cons = \
|
||||
stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)
|
||||
|
||||
# sort species_info by their fitness. (push nan to the end)
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||
species_info = species_info[sort_indices]
|
||||
center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
spawn_number = cal_spawn_numbers(species_info, jit_config)
|
||||
|
||||
# crossover info
|
||||
winner, loser, elite_mask = \
|
||||
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
|
||||
|
||||
jax.debug.print("{}, {}", fitness, winner)
|
||||
jax.debug.print("{}", fitness[winner])
|
||||
|
||||
return species_info, center_nodes, center_cons, winner, loser, elite_mask
|
||||
|
||||
|
||||
def update_species_fitness(species_info, idx2species, fitness):
|
||||
"""
|
||||
obtain the fitness of the species by the fitness of each individual.
|
||||
use max criterion.
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
species_key = species_info[idx, 0]
|
||||
s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf)
|
||||
f = jnp.max(s_fitness)
|
||||
return f
|
||||
|
||||
return vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
|
||||
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
stagnation species.
|
||||
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
|
||||
elitism species never stagnation
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
s_fitness = species_fitness[idx]
|
||||
species_key, best_score, last_update = species_info[idx]
|
||||
# stagnation condition
|
||||
return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
|
||||
|
||||
st = vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
# elite species will not be stagnation
|
||||
species_rank = rank_elements(species_fitness)
|
||||
st = jnp.where(species_rank < jit_config['species_elitism'], False, st) # elitism never stagnation
|
||||
|
||||
# set stagnation species to nan
|
||||
species_info = jnp.where(st[:, None], jnp.nan, species_info)
|
||||
center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes)
|
||||
center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons)
|
||||
species_fitness = jnp.where(st, jnp.nan, species_fitness)
|
||||
|
||||
return species_fitness, species_info, center_nodes, center_cons
|
||||
|
||||
|
||||
def cal_spawn_numbers(species_info, jit_config):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
the species with higher fitness will have more members
|
||||
Linear ranking selection
|
||||
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
|
||||
"""
|
||||
|
||||
is_species_valid = ~jnp.isnan(species_info[:, 0])
|
||||
valid_species_num = jnp.sum(is_species_valid)
|
||||
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
|
||||
|
||||
rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1]
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||
|
||||
return spawn_number
|
||||
|
||||
|
||||
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
|
||||
|
||||
species_size = species_info.shape[0]
|
||||
pop_size = fitness.shape[0]
|
||||
s_idx = jnp.arange(species_size)
|
||||
p_idx = jnp.arange(pop_size)
|
||||
|
||||
def aux_func(key, idx):
|
||||
members = idx2species == species_info[idx, 0]
|
||||
members_num = jnp.sum(members)
|
||||
|
||||
members_fitness = jnp.where(members, fitness, jnp.nan)
|
||||
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
|
||||
|
||||
elite_size = jit_config['genome_elitism']
|
||||
survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
|
||||
|
||||
select_pro = (p_idx < survive_size) / survive_size
|
||||
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
|
||||
|
||||
# elite
|
||||
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
|
||||
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
|
||||
elite = jnp.where(p_idx < elite_size, True, False)
|
||||
return fa, ma, elite
|
||||
|
||||
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
|
||||
|
||||
spawn_number_cum = jnp.cumsum(spawn_number)
|
||||
|
||||
def aux_func(idx):
|
||||
loc = jnp.argmax(idx < spawn_number_cum)
|
||||
|
||||
# elite genomes are at the beginning of the species
|
||||
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
|
||||
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
|
||||
|
||||
part1, part2, elite_mask = vmap(aux_func)(p_idx)
|
||||
|
||||
is_part1_win = fitness[part1] >= fitness[part2]
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return winner, loser, elite_mask
|
||||
@@ -1,166 +0,0 @@
|
||||
"""
|
||||
contains operations on the population: creating the next generation and population speciation.
|
||||
"""
|
||||
import jax
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
|
||||
center_nodes, center_cons, species_keys, new_species_key_start,
|
||||
jit_config):
|
||||
# create next generation
|
||||
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask,
|
||||
new_node_keys, jit_config)
|
||||
|
||||
# speciate
|
||||
idx2specie, spe_center_nodes, spe_center_cons, species_keys = \
|
||||
speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config)
|
||||
|
||||
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
|
||||
# prepare random keys
|
||||
pop_size = pop_nodes.shape[0]
|
||||
k1, k2 = jax.random.split(rand_key, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections
|
||||
lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections
|
||||
npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
# batch mutation
|
||||
mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None))
|
||||
m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
|
||||
|
||||
@jit
|
||||
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
pop_cons: (pop_size, C, 4)
|
||||
spe_center_nodes: (species_size, N, 5)
|
||||
spe_center_cons: (species_size, C, 4)
|
||||
"""
|
||||
pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]
|
||||
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
|
||||
s2p_distance_func = vmap(
|
||||
o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population
|
||||
)
|
||||
|
||||
# idx to specie key
|
||||
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||
|
||||
# part 1: find new centers
|
||||
# the distance between each species' center and each genome in population
|
||||
s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)
|
||||
|
||||
def find_new_centers(i, carry):
|
||||
i2s, cn, cc = carry
|
||||
# find new center
|
||||
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||
|
||||
# check species[i] exist or not
|
||||
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
||||
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
|
||||
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
|
||||
|
||||
i2s = i2s.at[idx].set(species_keys[i])
|
||||
cn = cn.at[i].set(pop_nodes[idx])
|
||||
cc = cc.at[i].set(pop_cons[idx])
|
||||
return i2s, cn, cc
|
||||
|
||||
idx2specie, center_nodes, center_cons = \
|
||||
jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry # sk is short for species_keys, ck is short for current key
|
||||
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
|
||||
i2s, scn, scc, sk, ck = jax.lax.cond(
|
||||
sk[i] == I_INT, # whether the current species is existing or not
|
||||
create_new_specie, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cn, cc, sk, ck)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc, sk, ck
|
||||
|
||||
def create_new_specie(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(i2s == I_INT)
|
||||
|
||||
# assign it to the new species
|
||||
sk = sk.at[i].set(ck)
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
|
||||
# update center genomes
|
||||
cn = cn.at[i].set(pop_nodes[idx])
|
||||
cc = cc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
|
||||
return i2s, cn, cc, sk, ck + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry
|
||||
|
||||
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
|
||||
|
||||
return i2s, cn, cc, sk, ck
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, cn, cc, sk = carry
|
||||
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
||||
|
||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
|
||||
return i2s
|
||||
|
||||
current_new_key = new_species_key_start
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, idx2specie, center_nodes, center_cons, species_keys, current_new_key)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
|
||||
|
||||
return idx2specie, center_nodes, center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
@@ -3,13 +3,13 @@ from typing import Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import jit, vmap
|
||||
|
||||
from configs import Configer
|
||||
from function_factory import FunctionFactory
|
||||
from algorithms.neat import initialize_genomes, expand, expand_single
|
||||
from algorithms.neat import initialize_genomes
|
||||
|
||||
from algorithms.neat.jit_species import update_species
|
||||
from algorithms.neat.operations import create_next_generation_then_speciate
|
||||
from algorithms.neat.population import create_next_generation, speciate, update_species
|
||||
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -17,30 +17,27 @@ class Pipeline:
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, function_factory=None, seed=42):
|
||||
def __init__(self, config, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.config = config # global config
|
||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
|
||||
|
||||
self.symbols = {
|
||||
'P': self.config['pop_size'],
|
||||
'N': self.config['init_maximum_nodes'],
|
||||
'C': self.config['init_maximum_connections'],
|
||||
'S': self.config['init_maximum_species'],
|
||||
}
|
||||
self.P = config['pop_size']
|
||||
self.N = config['init_maximum_nodes']
|
||||
self.C = config['init_maximum_connections']
|
||||
self.S = config['init_maximum_species']
|
||||
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
||||
self.species_info = np.full((self.symbols['S'], 3), np.nan)
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
|
||||
self.species_info = np.full((self.S, 3), np.nan)
|
||||
self.species_info[0, :] = 0, -np.inf, 0
|
||||
self.idx2species = np.zeros(self.symbols['P'], dtype=np.int32)
|
||||
self.center_nodes = np.full((self.symbols['S'], self.symbols['N'], 5), np.nan)
|
||||
self.center_cons = np.full((self.symbols['S'], self.symbols['C'], 4), np.nan)
|
||||
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
||||
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
|
||||
self.center_cons = np.full((self.S, self.C, 4), np.nan)
|
||||
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
|
||||
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
|
||||
|
||||
@@ -49,7 +46,10 @@ class Pipeline:
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
print(self.config)
|
||||
|
||||
self.pop_unflatten_connections = jit(vmap(unflatten_connections))
|
||||
self.pop_topological_sort = jit(vmap(topological_sort))
|
||||
self.forward = create_forward_function(config)
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
@@ -71,52 +71,28 @@ class Pipeline:
|
||||
e.g. numerical regression; Hyper-NEAT
|
||||
|
||||
"""
|
||||
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
|
||||
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
|
||||
u_pop_cons = self.pop_unflatten_connections(self.pop_nodes, self.pop_cons)
|
||||
pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons)
|
||||
|
||||
if self.config['forward_way'] == 'single':
|
||||
forward_funcs = []
|
||||
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
|
||||
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
|
||||
forward_funcs.append(func)
|
||||
return forward_funcs
|
||||
|
||||
elif self.config['forward_way'] == 'pop':
|
||||
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
elif self.config['forward_way'] == 'common':
|
||||
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# only common mode is supported currently
|
||||
assert self.config['forward_way'] == 'common'
|
||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
|
||||
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
|
||||
update_species(self.randkey, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
|
||||
update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
|
||||
# node keys to be used in the mutation process
|
||||
new_node_keys = np.arange(self.generation * self.config['pop_size'],
|
||||
self.generation * self.config['pop_size'] + self.config['pop_size'])
|
||||
self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
||||
elite_mask, self.generation, self.jit_config)
|
||||
|
||||
# create the next generation and then speciate the population
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
create_next_generation_then_speciate(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
|
||||
center_cons, species_keys, species_key_start, self.jit_config)
|
||||
|
||||
# carry data to cpu
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
|
||||
|
||||
# update randkey
|
||||
self.randkey = jax.random.split(self.randkey)[0]
|
||||
|
||||
def get_func(self, name):
|
||||
return self.function_factory.get(name, self.symbols)
|
||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate(
|
||||
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
|
||||
self.jit_config)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
307
algorithms/neat/population.py
Normal file
307
algorithms/neat/population.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
contains operations on the population: creating the next generation and population speciation.
|
||||
"""
|
||||
import jax
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first, rank_elements
|
||||
|
||||
|
||||
@jit
|
||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
args:
|
||||
randkey: random key
|
||||
fitness: Array[(pop_size,), float], the fitness of each individual
|
||||
species_keys: Array[(species_size, 3), float], the information of each species
|
||||
[species_key, best_score, last_update]
|
||||
idx2species: Array[(pop_size,), int], map the individual to its species
|
||||
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
|
||||
center_cons: Array[(species_size, C, 4), float], the center connections of each species
|
||||
generation: int, current generation
|
||||
jit_config: Dict, the configuration of jit functions
|
||||
"""
|
||||
|
||||
# update the fitness of each species
|
||||
species_fitness = update_species_fitness(species_info, idx2species, fitness)
|
||||
|
||||
# stagnation species
|
||||
species_fitness, species_info, center_nodes, center_cons = \
|
||||
stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)
|
||||
|
||||
# sort species_info by their fitness. (push nan to the end)
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||
species_info = species_info[sort_indices]
|
||||
center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
spawn_number = cal_spawn_numbers(species_info, jit_config)
|
||||
|
||||
# crossover info
|
||||
winner, loser, elite_mask = \
|
||||
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
|
||||
|
||||
return species_info, center_nodes, center_cons, winner, loser, elite_mask
|
||||
|
||||
|
||||
def update_species_fitness(species_info, idx2species, fitness):
|
||||
"""
|
||||
obtain the fitness of the species by the fitness of each individual.
|
||||
use max criterion.
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
species_key = species_info[idx, 0]
|
||||
s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf)
|
||||
f = jnp.max(s_fitness)
|
||||
return f
|
||||
|
||||
return vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
|
||||
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
stagnation species.
|
||||
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
|
||||
elitism species never stagnation
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
s_fitness = species_fitness[idx]
|
||||
species_key, best_score, last_update = species_info[idx]
|
||||
st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
|
||||
last_update = jnp.where(s_fitness > best_score, generation, last_update)
|
||||
best_score = jnp.where(s_fitness > best_score, s_fitness, best_score)
|
||||
# stagnation condition
|
||||
return st, jnp.array([species_key, best_score, last_update])
|
||||
|
||||
spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
# elite species will not be stagnation
|
||||
species_rank = rank_elements(species_fitness)
|
||||
spe_st = jnp.where(species_rank < jit_config['species_elitism'], False, spe_st) # elitism never stagnation
|
||||
|
||||
# set stagnation species to nan
|
||||
species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
|
||||
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, center_nodes)
|
||||
center_cons = jnp.where(spe_st[:, None, None], jnp.nan, center_cons)
|
||||
species_fitness = jnp.where(spe_st, jnp.nan, species_fitness)
|
||||
|
||||
return species_fitness, species_info, center_nodes, center_cons
|
||||
|
||||
|
||||
def cal_spawn_numbers(species_info, jit_config):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
the species with higher fitness will have more members
|
||||
Linear ranking selection
|
||||
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
|
||||
"""
|
||||
|
||||
is_species_valid = ~jnp.isnan(species_info[:, 0])
|
||||
valid_species_num = jnp.sum(is_species_valid)
|
||||
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
|
||||
|
||||
rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1]
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||
|
||||
return spawn_number
|
||||
|
||||
|
||||
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
|
||||
|
||||
species_size = species_info.shape[0]
|
||||
pop_size = fitness.shape[0]
|
||||
s_idx = jnp.arange(species_size)
|
||||
p_idx = jnp.arange(pop_size)
|
||||
|
||||
# def aux_func(key, idx):
|
||||
def aux_func(key, idx):
|
||||
members = idx2species == species_info[idx, 0]
|
||||
members_num = jnp.sum(members)
|
||||
|
||||
members_fitness = jnp.where(members, fitness, -jnp.inf)
|
||||
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
|
||||
|
||||
elite_size = jit_config['genome_elitism']
|
||||
survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
|
||||
|
||||
select_pro = (p_idx < survive_size) / survive_size
|
||||
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
|
||||
|
||||
# elite
|
||||
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
|
||||
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
|
||||
elite = jnp.where(p_idx < elite_size, True, False)
|
||||
return fa, ma, elite
|
||||
|
||||
# fas, mas, elites = jax.lax.max(aux_func, (jax.random.split(randkey, species_size), s_idx))
|
||||
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
|
||||
|
||||
spawn_number_cum = jnp.cumsum(spawn_number)
|
||||
|
||||
def aux_func(idx):
|
||||
loc = jnp.argmax(idx < spawn_number_cum)
|
||||
|
||||
# elite genomes are at the beginning of the species
|
||||
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
|
||||
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
|
||||
|
||||
part1, part2, elite_mask = vmap(aux_func)(p_idx)
|
||||
|
||||
is_part1_win = fitness[part1] >= fitness[part2]
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return winner, loser, elite_mask
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
|
||||
# prepare random keys
|
||||
pop_size = pop_nodes.shape[0]
|
||||
new_node_keys = jnp.arange(pop_size) + generation * pop_size
|
||||
|
||||
k1, k2 = jax.random.split(rand_key, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections
|
||||
lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections
|
||||
npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
# batch mutation
|
||||
mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None))
|
||||
m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
|
||||
|
||||
@jit
|
||||
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
pop_cons: (pop_size, C, 4)
|
||||
spe_center_nodes: (species_size, N, 5)
|
||||
spe_center_cons: (species_size, C, 4)
|
||||
"""
|
||||
pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]
|
||||
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
|
||||
s2p_distance_func = vmap(
|
||||
o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population
|
||||
)
|
||||
|
||||
# idx to specie key
|
||||
idx2specie = jnp.full((pop_size,), jnp.nan) # I_INT means not assigned to any species
|
||||
|
||||
# part 1: find new centers
|
||||
# the distance between each species' center and each genome in population
|
||||
s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)
|
||||
|
||||
def find_new_centers(i, carry):
|
||||
i2s, cn, cc = carry
|
||||
# find new center
|
||||
idx = argmin_with_mask(s2p_distance[i], mask=jnp.isnan(i2s))
|
||||
|
||||
# check species[i] exist or not
|
||||
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
||||
idx = jnp.where(~jnp.isnan(species_info[i, 0]), idx, I_INT)
|
||||
i = jnp.where(~jnp.isnan(species_info[i, 0]), i, I_INT)
|
||||
|
||||
i2s = i2s.at[idx].set(species_info[i, 0])
|
||||
cn = cn.at[i].set(pop_nodes[idx])
|
||||
cc = cc.at[i].set(pop_cons[idx])
|
||||
return i2s, cn, cc
|
||||
|
||||
idx2specie, center_nodes, center_cons = \
|
||||
jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cn, cc, si, ck = carry # si is short for species_info, ck is short for current key
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cn, cc, si, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
|
||||
i2s, scn, scc, si, ck = jax.lax.cond(
|
||||
jnp.isnan(si[i, 0]), # whether the current species is existing or not
|
||||
create_new_specie, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cn, cc, si, ck)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc, si, ck
|
||||
|
||||
def create_new_specie(carry):
|
||||
i, i2s, cn, cc, si, ck = carry
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
# assign it to the new species
|
||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation])) # [key, best score, last update generation]
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
|
||||
# update center genomes
|
||||
cn = cn.at[i].set(pop_nodes[idx])
|
||||
cc = cc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s = speciate_by_threshold((i, i2s, cn, cc, si))
|
||||
return i2s, cn, cc, si, ck + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cn, cc, si, ck = carry
|
||||
i2s = speciate_by_threshold((i, i2s, cn, cc, si))
|
||||
return i2s, cn, cc, si, ck
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, cn, cc, si = carry
|
||||
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
||||
|
||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||
i2s = jnp.where(close_enough_mask & jnp.isnan(i2s), si[i, 0], i2s)
|
||||
return i2s
|
||||
|
||||
species_keys = species_info[:, 0]
|
||||
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, center_nodes, center_cons, species_info, _ = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, idx2specie, center_nodes, center_cons, species_info, current_new_key)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie)
|
||||
return idx2specie, center_nodes, center_cons, species_info
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
Species Controller in NEAT.
|
||||
The code are modified from neat-python.
|
||||
See
|
||||
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .genome.utils import I_INT
|
||||
|
||||
|
||||
class Species(object):
|
||||
|
||||
def __init__(self, key, generation):
|
||||
self.key = key
|
||||
self.created = generation
|
||||
self.last_improved = generation
|
||||
self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections)
|
||||
self.members: NDArray = None # idx in pop_nodes, pop_connections,
|
||||
self.fitness = None
|
||||
self.member_fitnesses = None
|
||||
self.adjusted_fitness = None
|
||||
self.fitness_history: List[float] = []
|
||||
|
||||
def update(self, representative, members):
|
||||
self.representative = representative
|
||||
self.members = members
|
||||
|
||||
def get_fitnesses(self, fitnesses):
|
||||
return fitnesses[self.members]
|
||||
|
||||
|
||||
class SpeciesController:
|
||||
"""
|
||||
A class to control the species
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
self.species_elitism = self.config['species_elitism']
|
||||
self.pop_size = self.config['pop_size']
|
||||
self.max_stagnation = self.config['max_stagnation']
|
||||
self.min_species_size = self.config['min_species_size']
|
||||
self.genome_elitism = self.config['genome_elitism']
|
||||
self.survival_threshold = self.config['survival_threshold']
|
||||
|
||||
self.species: Dict[int, Species] = {} # species_id -> species
|
||||
|
||||
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
|
||||
"""
|
||||
speciate for the first generation
|
||||
:param pop_connections:
|
||||
:param pop_nodes:
|
||||
:return:
|
||||
"""
|
||||
pop_size = pop_nodes.shape[0]
|
||||
species_id = 0 # the first species
|
||||
s = Species(species_id, 0)
|
||||
members = np.array(list(range(pop_size)))
|
||||
|
||||
s.update((pop_nodes[0], pop_connections[0]), members)
|
||||
self.species[species_id] = s
|
||||
|
||||
def __update_species_fitnesses(self, fitnesses):
|
||||
"""
|
||||
update the fitness of each species
|
||||
:param fitnesses:
|
||||
:return:
|
||||
"""
|
||||
for sid, s in self.species.items():
|
||||
s.member_fitnesses = s.get_fitnesses(fitnesses)
|
||||
# use the max score to represent the fitness of the species
|
||||
s.fitness = np.max(s.member_fitnesses)
|
||||
s.fitness_history.append(s.fitness)
|
||||
s.adjusted_fitness = None
|
||||
|
||||
def __stagnation(self, generation):
|
||||
"""
|
||||
:param generation:
|
||||
:return: whether the species is stagnated
|
||||
"""
|
||||
species_data = []
|
||||
for sid, s in self.species.items():
|
||||
if s.fitness_history:
|
||||
prev_fitness = max(s.fitness_history)
|
||||
else:
|
||||
prev_fitness = float('-inf')
|
||||
|
||||
if s.fitness > prev_fitness:
|
||||
s.last_improved = generation
|
||||
|
||||
species_data.append((sid, s))
|
||||
|
||||
# Sort in descending fitness order.
|
||||
species_data.sort(key=lambda x: x[1].fitness, reverse=True)
|
||||
|
||||
result = []
|
||||
for idx, (sid, s) in enumerate(species_data):
|
||||
|
||||
if idx < self.species_elitism: # elitism species never stagnate!
|
||||
is_stagnant = False
|
||||
else:
|
||||
stagnant_time = generation - s.last_improved
|
||||
is_stagnant = stagnant_time > self.max_stagnation
|
||||
|
||||
result.append((sid, s, is_stagnant))
|
||||
return result
|
||||
|
||||
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
"""
|
||||
:param fitnesses:
|
||||
:param generation:
|
||||
:return: crossover_pair for next generation.
|
||||
# int -> idx in the pop_nodes, pop_connections of elitism
|
||||
# (int, int) -> the father and mother idx to be crossover
|
||||
"""
|
||||
# Filter out stagnated species, collect the set of non-stagnated
|
||||
# species members, and compute their average adjusted fitness.
|
||||
# The average adjusted fitness scheme (normalized to the interval
|
||||
# [0, 1]) allows the use of negative fitness values without
|
||||
# interfering with the shared fitness scheme.
|
||||
|
||||
min_fitness = np.inf
|
||||
max_fitness = -np.inf
|
||||
|
||||
remaining_species = []
|
||||
for stag_sid, stag_s, stagnant in self.__stagnation(generation):
|
||||
if not stagnant:
|
||||
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
|
||||
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
|
||||
remaining_species.append(stag_s)
|
||||
|
||||
# No species left.
|
||||
assert remaining_species
|
||||
|
||||
|
||||
# TODO: Too complex!
|
||||
# Compute each species' member size in the next generation.
|
||||
|
||||
# Do not allow the fitness range to be zero, as we divide by it below.
|
||||
# TODO: The ``1.0`` below is rather arbitrary, and should be configurable.
|
||||
fitness_range = max(1.0, max_fitness - min_fitness)
|
||||
for afs in remaining_species:
|
||||
# Compute adjusted fitness.
|
||||
msf = afs.fitness
|
||||
af = (msf - min_fitness) / fitness_range # make adjusted fitness in [0, 1]
|
||||
afs.adjusted_fitness = af
|
||||
adjusted_fitnesses = [s.adjusted_fitness for s in remaining_species]
|
||||
previous_sizes = [len(s.members) for s in remaining_species]
|
||||
min_species_size = max(self.min_species_size, self.genome_elitism)
|
||||
spawn_amounts = compute_spawn(adjusted_fitnesses, previous_sizes, self.pop_size, min_species_size)
|
||||
assert sum(spawn_amounts) == self.pop_size
|
||||
|
||||
# generate new population and speciate
|
||||
self.species = {}
|
||||
# int -> idx in the pop_nodes, pop_connections of elitism
|
||||
# (int, int) -> the father and mother idx to be crossover
|
||||
|
||||
part1, part2, elite_mask = [], [], []
|
||||
|
||||
for spawn, s in zip(spawn_amounts, remaining_species):
|
||||
assert spawn >= self.genome_elitism
|
||||
|
||||
# retain remain species to next generation
|
||||
old_members, member_fitnesses = s.members, s.member_fitnesses
|
||||
s.members = []
|
||||
self.species[s.key] = s
|
||||
|
||||
# add elitism genomes to next generation
|
||||
sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, member_fitnesses)
|
||||
if self.genome_elitism > 0:
|
||||
for m in sorted_members[:self.genome_elitism]:
|
||||
part1.append(m)
|
||||
part2.append(m)
|
||||
elite_mask.append(True)
|
||||
spawn -= 1
|
||||
|
||||
if spawn <= 0:
|
||||
continue
|
||||
|
||||
# add genome to be crossover to next generation
|
||||
repro_cutoff = int(np.ceil(self.survival_threshold * len(sorted_members)))
|
||||
repro_cutoff = max(repro_cutoff, 2)
|
||||
# only use good genomes to crossover
|
||||
sorted_members = sorted_members[:repro_cutoff]
|
||||
|
||||
# TODO: Genome with higher fitness should be more likely to be selected?
|
||||
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
|
||||
part1.extend(sorted_members[list_idx1])
|
||||
part2.extend(sorted_members[list_idx2])
|
||||
elite_mask.extend([False] * spawn)
|
||||
|
||||
part1_fitness, part2_fitness = fitnesses[part1], fitnesses[part2]
|
||||
is_part1_win = part1_fitness >= part2_fitness
|
||||
winner_part = np.where(is_part1_win, part1, part2)
|
||||
loser_part = np.where(is_part1_win, part2, part1)
|
||||
|
||||
return winner_part, loser_part, np.array(elite_mask)
|
||||
|
||||
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
|
||||
for idx, key in enumerate(species_keys):
|
||||
if key == I_INT:
|
||||
continue
|
||||
|
||||
members = np.where(idx2specie == key)[0]
|
||||
assert len(members) > 0
|
||||
|
||||
if key not in self.species:
|
||||
# the new specie created in this generation
|
||||
s = Species(key, generation)
|
||||
self.species[key] = s
|
||||
|
||||
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
|
||||
|
||||
def ask(self, fitnesses, generation, symbols):
|
||||
self.__update_species_fitnesses(fitnesses)
|
||||
|
||||
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
|
||||
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
|
||||
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
|
||||
species_keys = np.full((symbols['S'], ), I_INT)
|
||||
|
||||
for idx, (key, specie) in enumerate(self.species.items()):
|
||||
center_nodes[idx], center_cons[idx] = specie.representative
|
||||
species_keys[idx] = key
|
||||
|
||||
next_new_specie_key = max(self.species.keys()) + 1
|
||||
|
||||
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
|
||||
|
||||
|
||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
"""
|
||||
Code from neat-python, the only modification is to fix the population size for each generation.
|
||||
Compute the proper number of offspring per species (proportional to fitness).
|
||||
"""
|
||||
af_sum = sum(adjusted_fitness)
|
||||
|
||||
spawn_amounts = []
|
||||
for af, ps in zip(adjusted_fitness, previous_sizes):
|
||||
if af_sum > 0:
|
||||
s = max(min_species_size, af / af_sum * pop_size)
|
||||
else:
|
||||
s = min_species_size
|
||||
|
||||
d = (s - ps) * 0.5
|
||||
c = int(round(d))
|
||||
spawn = ps
|
||||
if abs(c) > 0:
|
||||
spawn += c
|
||||
elif d > 0:
|
||||
spawn += 1
|
||||
elif d < 0:
|
||||
spawn -= 1
|
||||
|
||||
spawn_amounts.append(spawn)
|
||||
|
||||
# Normalize the spawn amounts so that the next generation is roughly
|
||||
# the population size requested by the user.
|
||||
total_spawn = sum(spawn_amounts)
|
||||
norm = pop_size / total_spawn
|
||||
spawn_amounts = [max(min_species_size, int(round(n * norm))) for n in spawn_amounts]
|
||||
|
||||
# for batch parallelization, pop size must be a fixed value.
|
||||
total_amounts = sum(spawn_amounts)
|
||||
spawn_amounts[0] += pop_size - total_amounts
|
||||
assert sum(spawn_amounts) == pop_size, "Population size is not stable."
|
||||
|
||||
return spawn_amounts
|
||||
|
||||
|
||||
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
|
||||
-> Tuple[NDArray, NDArray]:
|
||||
sorted_idx = np.argsort(fitnesses)[::-1]
|
||||
return members[sorted_idx], fitnesses[sorted_idx]
|
||||
@@ -1,7 +1,7 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
init_maximum_nodes = 200
|
||||
init_maximum_nodes = 50
|
||||
init_maximum_connections = 200
|
||||
init_maximum_species = 10
|
||||
expand_coe = 1.5
|
||||
@@ -11,9 +11,9 @@ batch_size = 4
|
||||
|
||||
[population]
|
||||
fitness_threshold = 100000
|
||||
generation_limit = 100
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 150
|
||||
pop_size = 1500
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
|
||||
from evox import algorithms, problems, pipelines
|
||||
from evox.monitors import StdSOMonitor
|
||||
|
||||
monitor = StdSOMonitor()
|
||||
|
||||
pso = algorithms.PSO(
|
||||
lb=jnp.full(shape=(2,), fill_value=-32),
|
||||
ub=jnp.full(shape=(2,), fill_value=32),
|
||||
pop_size=100,
|
||||
)
|
||||
|
||||
ackley = problems.classic.Ackley()
|
||||
|
||||
pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit)
|
||||
|
||||
key = jax.random.PRNGKey(42)
|
||||
state = pipeline.init(key)
|
||||
|
||||
# run the pipeline for 100 steps
|
||||
for i in range(100):
|
||||
state = pipeline.step(state)
|
||||
|
||||
print(monitor.get_min_fitness())
|
||||
@@ -1,28 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from jit_pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return np.array(fitnesses) # returns a list
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
pipeline = Pipeline(config, seed=6)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
print(nodes, cons)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,7 +1,8 @@
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from pipeline import Pipeline
|
||||
from algorithms.neat.pipeline import Pipeline
|
||||
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
@@ -21,7 +22,8 @@ def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
pipeline = Pipeline(config, seed=6)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
print(nodes, cons)
|
||||
print(nodes)
|
||||
print(cons)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from algorithms.neat import create_forward, topological_sort, \
|
||||
unflatten_connections, create_next_generation_then_speciate
|
||||
|
||||
|
||||
def hash_symbols(symbols):
|
||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||
|
||||
|
||||
class FunctionFactory:
|
||||
"""
|
||||
Creates and compiles functions used in the NEAT pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, jit_config):
|
||||
self.config = config
|
||||
self.jit_config = jit_config
|
||||
|
||||
self.func_dict = {}
|
||||
self.function_info = {}
|
||||
|
||||
# (inputs_nums, ) -> (outputs_nums, )
|
||||
forward = create_forward(config) # input size (inputs_nums, )
|
||||
|
||||
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
|
||||
batch_forward = vmap(forward, in_axes=(0, None, None, None))
|
||||
|
||||
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
|
||||
|
||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
self.function_info = {
|
||||
"pop_unflatten_connections": {
|
||||
'func': vmap(unflatten_connections),
|
||||
'lowers': [
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 'C', 4), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
"pop_topological_sort": {
|
||||
'func': vmap(topological_sort),
|
||||
'lowers': [
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32},
|
||||
]
|
||||
},
|
||||
|
||||
"batch_forward": {
|
||||
'func': batch_forward,
|
||||
'lowers': [
|
||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('N',), 'type': np.int32},
|
||||
{'shape': ('N', 5), 'type': np.float32},
|
||||
{'shape': (2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
"pop_batch_forward": {
|
||||
'func': pop_batch_forward,
|
||||
'lowers': [
|
||||
{'shape': ('P', config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('P', 'N'), 'type': np.int32},
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
'common_forward': {
|
||||
'func': common_forward,
|
||||
'lowers': [
|
||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('P', 'N'), 'type': np.int32},
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
'create_next_generation_then_speciate': {
|
||||
'func': create_next_generation_then_speciate,
|
||||
'lowers': [
|
||||
{'shape': (2,), 'type': np.uint32}, # rand_key
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
|
||||
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
|
||||
{'shape': ('P',), 'type': np.int32}, # winner
|
||||
{'shape': ('P',), 'type': np.int32}, # loser
|
||||
{'shape': ('P',), 'type': bool}, # elite_mask
|
||||
{'shape': ('P',), 'type': np.int32}, # new_node_keys
|
||||
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
|
||||
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
|
||||
{'shape': ('S',), 'type': np.int32}, # species_keys
|
||||
{'shape': (), 'type': np.int32}, # new_species_key_start
|
||||
"jit_config"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def get(self, name, symbols):
|
||||
if (name, hash_symbols(symbols)) not in self.func_dict:
|
||||
self.compile(name, symbols)
|
||||
return self.func_dict[name, hash_symbols(symbols)]
|
||||
|
||||
def compile(self, name, symbols):
|
||||
# prepare function prototype
|
||||
func = self.function_info[name]['func']
|
||||
|
||||
# prepare lower operands
|
||||
lowers_operands = []
|
||||
for lower in self.function_info[name]['lowers']:
|
||||
if isinstance(lower, dict):
|
||||
shape = list(lower['shape'])
|
||||
for i, s in enumerate(shape):
|
||||
if s in symbols:
|
||||
shape[i] = symbols[s]
|
||||
assert isinstance(shape[i], int)
|
||||
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
||||
|
||||
elif lower == "jit_config":
|
||||
lowers_operands.append(self.jit_config)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid lower operand")
|
||||
|
||||
# compile
|
||||
compiled_func = jit(func).lower(*lowers_operands).compile()
|
||||
|
||||
# save for reuse
|
||||
self.func_dict[name, hash_symbols(symbols)] = compiled_func
|
||||
189
pipeline.py
189
pipeline.py
@@ -1,189 +0,0 @@
|
||||
import time
|
||||
from typing import Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
|
||||
from configs import Configer
|
||||
from function_factory import FunctionFactory
|
||||
from algorithms.neat import initialize_genomes, expand, expand_single, SpeciesController
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, function_factory=None, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.config = config # global config
|
||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
|
||||
|
||||
self.symbols = {
|
||||
'P': self.config['pop_size'],
|
||||
'N': self.config['init_maximum_nodes'],
|
||||
'C': self.config['init_maximum_connections'],
|
||||
'S': self.config['init_maximum_species'],
|
||||
}
|
||||
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.species_controller = SpeciesController(self.config)
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_genome = None
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
print(self.config)
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
Creates a function that receives a genome and returns a forward function.
|
||||
There are 3 types of config['forward_way']: {'single', 'pop', 'common'}
|
||||
|
||||
single:
|
||||
Create pop_size number of forward functions.
|
||||
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
|
||||
e.g. RL task
|
||||
|
||||
pop:
|
||||
Create a single forward function, which use only once calculation for the population.
|
||||
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||
|
||||
common:
|
||||
Special case of pop. The population has the same inputs.
|
||||
The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||
e.g. numerical regression; Hyper-NEAT
|
||||
|
||||
"""
|
||||
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
|
||||
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
|
||||
|
||||
if self.config['forward_way'] == 'single':
|
||||
forward_funcs = []
|
||||
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
|
||||
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
|
||||
forward_funcs.append(func)
|
||||
return forward_funcs
|
||||
|
||||
elif self.config['forward_way'] == 'pop':
|
||||
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
elif self.config['forward_way'] == 'common':
|
||||
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
|
||||
winner, loser, elite_mask, center_nodes, center_cons, species_keys, species_key_start = \
|
||||
self.species_controller.ask(fitnesses, self.generation, self.symbols)
|
||||
|
||||
# node keys to be used in the mutation process
|
||||
new_node_keys = np.arange(self.generation * self.config['pop_size'],
|
||||
self.generation * self.config['pop_size'] + self.config['pop_size'])
|
||||
|
||||
# create the next generation and then speciate the population
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
self.get_func('create_next_generation_then_speciate') \
|
||||
(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
|
||||
center_cons, species_keys, species_key_start, self.jit_config)
|
||||
|
||||
# carry data to cpu
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
|
||||
|
||||
self.species_controller.tell(idx2specie, center_nodes, center_cons, species_keys, self.generation)
|
||||
|
||||
# expand the population if needed
|
||||
self.expand()
|
||||
|
||||
# update randkey
|
||||
self.randkey = jax.random.split(self.randkey)[0]
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
Expand the population if needed.
|
||||
when the maximum node number >= N or the maximum connection number of >= C
|
||||
the population will expand
|
||||
"""
|
||||
|
||||
# analysis nodes
|
||||
pop_node_keys = self.pop_nodes[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
|
||||
max_node_size = np.max(pop_node_sizes)
|
||||
|
||||
# analysis connections
|
||||
pop_con_keys = self.pop_cons[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||
max_con_size = np.max(pop_node_sizes)
|
||||
|
||||
# expand if needed
|
||||
if max_node_size >= self.symbols['N'] or max_con_size >= self.symbols['C']:
|
||||
if max_node_size > self.symbols['N'] * self.config['pre_expand_threshold']:
|
||||
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
|
||||
print(f"pre node expand to {self.symbols['N']}!")
|
||||
|
||||
if max_con_size > self.symbols['C'] * self.config['pre_expand_threshold']:
|
||||
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
|
||||
print(f"pre connection expand to {self.symbols['C']}!")
|
||||
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.symbols['N'], self.symbols['C'])
|
||||
|
||||
def get_func(self, name):
|
||||
return self.function_factory.get(name, self.symbols)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
forward_func = self.ask()
|
||||
|
||||
tic = time.time()
|
||||
fitnesses = fitness_func(forward_func)
|
||||
self.evaluate_time += time.time() - tic
|
||||
|
||||
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||
|
||||
if analysis is not None:
|
||||
if analysis == "default":
|
||||
self.default_analysis(fitnesses)
|
||||
else:
|
||||
assert callable(analysis), f"Callable is needed here😅😅😅 A {analysis}?"
|
||||
analysis(fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||
print("Fitness limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
self.tell(fitnesses)
|
||||
print("Generation limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
||||
|
||||
new_timestamp = time.time()
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
Reference in New Issue
Block a user