The whole NEAT algorithm is rewritten in a functional programming style.

This commit is contained in:
wls2002
2023-06-29 09:28:49 +08:00
parent 114ff2b0cc
commit d28cef1a87
16 changed files with 371 additions and 1102 deletions

View File

@@ -1,6 +1,5 @@
"""
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single
from .operations import create_next_generation_then_speciate
from .species import SpeciesController
from .genome import create_forward_function, topological_sort, unflatten_connections
from .population import update_species, create_next_generation, speciate

View File

@@ -1,7 +1,6 @@
from .mutate import mutate
from .distance import distance
from .crossover import crossover
from .forward import create_forward
from .graph import topological_sort, check_cycles
from .utils import unflatten_connections
from .genome import initialize_genomes, expand, expand_single
from .forward import create_forward_function

View File

@@ -5,7 +5,7 @@ from jax import jit, vmap
from .utils import I_INT
def create_forward(config):
def create_forward_function(config):
"""
meta method to create forward function
"""
@@ -83,4 +83,22 @@ def create_forward(config):
return vals[output_idx]
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
batch_forward = vmap(forward, in_axes=(0, None, None, None))
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
if config['forward_way'] == 'single':
return jit(batch_forward)
elif config['forward_way'] == 'pop':
return jit(pop_batch_forward)
elif config['forward_way'] == 'common':
return jit(common_forward)
return forward

View File

@@ -65,55 +65,6 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
return pop_nodes, pop_cons
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow the node and connection tables of one genome to a larger capacity.

    The existing rows are copied to the top of freshly allocated tables and
    every remaining slot is filled with NaN.

    :param nodes: (N, 5)
    :param cons: (C, 4)
    :param new_N: number of rows of the returned node table
    :param new_C: number of rows of the returned connection table
    :return: (new_N, 5), (new_C, 4)
    """
    padded_nodes = np.empty((new_N, 5))
    padded_nodes.fill(np.nan)
    padded_nodes[:nodes.shape[0]] = nodes

    padded_cons = np.empty((new_C, 4))
    padded_cons.fill(np.nan)
    padded_cons[:cons.shape[0]] = cons

    return padded_nodes, padded_cons
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow the node and connection tables of a whole population to a larger capacity.

    Each genome keeps its existing rows at the top of its table; all newly
    added slots are filled with NaN.

    :param pop_nodes: (pop_size, N, 5)
    :param pop_cons: (pop_size, C, 4)
    :param new_N: number of node rows per genome in the result
    :param new_C: number of connection rows per genome in the result
    :return: (pop_size, new_N, 5), (pop_size, new_C, 4)
    """
    pop_size = pop_nodes.shape[0]

    padded_nodes = np.empty((pop_size, new_N, 5))
    padded_nodes.fill(np.nan)
    padded_nodes[:, :pop_nodes.shape[1]] = pop_nodes

    padded_cons = np.empty((pop_size, new_C, 4))
    padded_cons.fill(np.nan)
    padded_cons[:, :pop_cons.shape[1]] = pop_cons

    return padded_nodes, padded_cons
@jit
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
    """
    Count the occupied rows of a genome.

    A row is counted when its first column is not NaN.

    :return: (number of nodes, number of connections)
    """
    node_used = jnp.logical_not(jnp.isnan(nodes[:, 0]))
    cons_used = jnp.logical_not(jnp.isnan(cons[:, 0]))
    return jnp.sum(node_used), jnp.sum(cons_used)
@jit
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:

View File

@@ -59,12 +59,13 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
rank the element in the array.
if reverse is True, the rank is from large to small.
if reverse is True, the rank is from small to large. default large to small
"""
if reverse:
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))

View File

@@ -1,160 +0,0 @@
from functools import partial
import jax
from jax import jit, numpy as jnp, vmap
from .genome.utils import rank_elements
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
    """
    Update the species records and choose the parents of the next generation.

    args:
        randkey: random key
        fitness: Array[(pop_size,), float], the fitness of each individual
        species_info: Array[(species_size, 3), float], the information of each species
            [species_key, best_score, last_update]
        idx2species: Array[(pop_size,)], map the individual to the key of its species
        center_nodes: Array[(species_size, N, 5), float], the center nodes of each species
        center_cons: Array[(species_size, C, 4), float], the center connections of each species
        generation: int, current generation
        jit_config: Dict, the configuration of jit functions
    returns:
        species_info, center_nodes, center_cons: the species records sorted by species fitness
            (best first); removed species are NaN and come last
        winner, loser: Array[(pop_size,), int], indices of the fitter / less fit parent of each child
        elite_mask: Array[(pop_size,), bool], True for children that are copied elites
    """
    # update the fitness of each species
    species_fitness = update_species_fitness(species_info, idx2species, fitness)

    # stagnation species
    species_fitness, species_info, center_nodes, center_cons = \
        stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)

    # sort species by fitness, best first.
    # argsort places NaN last, so reversing it directly would move the stagnated species (NaN fitness)
    # to the front; map NaN to -inf first so that they end up at the end.
    sort_key = jnp.where(jnp.isnan(species_fitness), -jnp.inf, species_fitness)
    sort_indices = jnp.argsort(sort_key)[::-1]
    species_info = species_info[sort_indices]
    center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]

    # decide the number of members of each species by their fitness
    spawn_number = cal_spawn_numbers(species_info, jit_config)

    # crossover info
    winner, loser, elite_mask = \
        create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)

    return species_info, center_nodes, center_cons, winner, loser, elite_mask
def update_species_fitness(species_info, idx2species, fitness):
    """
    Compute the fitness of every species as the maximum fitness among its members.

    A species without any member in idx2species gets -inf.
    """
    def best_member_fitness(species_key):
        belongs = idx2species == species_key
        return jnp.max(jnp.where(belongs, fitness, -jnp.inf))

    return vmap(best_member_fitness)(species_info[:, 0])
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
    """
    Remove the species that have stagnated.

    A species has stagnated when its fitness does not exceed its recorded best score and more than
    jit_config['max_stagnation'] generations have passed since its last update.
    Species whose rank (as returned by rank_elements) is below jit_config['species_elitism'] are
    never removed.

    Removed species get NaN in species_fitness, species_info, center_nodes and center_cons.
    """
    best_score = species_info[:, 1]
    last_update = species_info[:, 2]

    # stagnation condition
    no_progress = species_fitness <= best_score
    too_long = generation - last_update > jit_config['max_stagnation']
    stagnated = no_progress & too_long

    # elite species are exempt from stagnation
    protected = rank_elements(species_fitness) < jit_config['species_elitism']
    stagnated = stagnated & ~protected

    # erase the stagnated species
    species_info = jnp.where(stagnated[:, None], jnp.nan, species_info)
    center_nodes = jnp.where(stagnated[:, None, None], jnp.nan, center_nodes)
    center_cons = jnp.where(stagnated[:, None, None], jnp.nan, center_cons)
    species_fitness = jnp.where(stagnated, jnp.nan, species_fitness)

    return species_fitness, species_info, center_nodes, center_cons
def cal_spawn_numbers(species_info, jit_config):
    """
    Split the population among the species by linear ranking selection.

    species_info is expected to be ordered best species first; a row whose key is NaN is an
    invalid species and receives no member. The i-th valid species gets a share proportional to
    (valid_species_num - i). Shares are floored to integers and the remainder is added to the
    first species so that the result sums to jit_config['pop_size'].

    e.g. 3 valid species, pop_size = 10 -> shares = [0.5, 0.33, 0.17], floored = [5, 3, 1],
    remainder 1 goes to the first species -> spawn_number = [6, 3, 1]
    """
    pop_size = jit_config['pop_size']

    valid = ~jnp.isnan(species_info[:, 0])
    valid_num = jnp.sum(valid)

    weights = valid_num - jnp.arange(species_info.shape[0])  # e.g. [3, 2, 1]
    weight_total = (valid_num + 1) * valid_num / 2  # e.g. 3 + 2 + 1 = 6
    share = jnp.where(valid, weights / weight_total, 0)  # invalid species get nothing

    spawn_number = jnp.floor(share * pop_size).astype(jnp.int32)

    # flooring may lose some members; give them to the first species
    shortfall = pop_size - jnp.sum(spawn_number)
    return spawn_number.at[0].add(shortfall)
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
    """
    Choose the two parents of every individual of the next generation.

    For each species, the members are sorted by fitness (best first) and parents are drawn
    uniformly, with replacement, from the best floor(survival_threshold * members_num) members
    (at least one). The first jit_config['genome_elitism'] slots of each species are reserved for
    its best members, which are paired with themselves and flagged in elite_mask.
    Slot idx of the new population is served by the species that spawn_number assigns it to.

    args:
        randkey: random key
        species_info: Array[(species_size, 3), float], column 0 is the species key
        idx2species: Array[(pop_size,)], the species key of each individual
        spawn_number: Array[(species_size,), int], number of children of each species
        fitness: Array[(pop_size,), float], the fitness of each individual
        jit_config: Dict, reads 'genome_elitism' and 'survival_threshold'
    returns:
        winner, loser: Array[(pop_size,), int], for each child the index of the parent with the
            higher (or equal) fitness and of the other parent
        elite_mask: Array[(pop_size,), bool], True where the child is an elite
    """
    species_size = species_info.shape[0]
    pop_size = fitness.shape[0]
    s_idx = jnp.arange(species_size)
    p_idx = jnp.arange(pop_size)

    def sample_parents(key, idx):
        members = idx2species == species_info[idx, 0]
        members_num = jnp.sum(members)

        # non-members get -inf rather than NaN: argsort places NaN last, so after the reversal
        # below NaN entries would come first and non-members would be selected as parents
        members_fitness = jnp.where(members, fitness, -jnp.inf)
        sorted_member_indices = jnp.argsort(members_fitness)[::-1]

        elite_size = jit_config['genome_elitism']
        survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
        # keep at least one survivor, otherwise the probabilities below become 0 / 0 = NaN
        survive_size = jnp.maximum(survive_size, 1)

        select_pro = (p_idx < survive_size) / survive_size
        fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)

        # elite
        fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
        ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
        elite = jnp.where(p_idx < elite_size, True, False)
        return fa, ma, elite

    fas, mas, elites = vmap(sample_parents)(jax.random.split(randkey, species_size), s_idx)

    spawn_number_cum = jnp.cumsum(spawn_number)

    def pick_for_slot(idx):
        # the first species whose cumulative spawn number exceeds idx owns this slot
        loc = jnp.argmax(idx < spawn_number_cum)

        # elite genomes are at the beginning of the species
        idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
        return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

    part1, part2, elite_mask = vmap(pick_for_slot)(p_idx)

    is_part1_win = fitness[part1] >= fitness[part2]
    winner = jnp.where(is_part1_win, part1, part2)
    loser = jnp.where(is_part1_win, part2, part1)

    return winner, loser, elite_mask

View File

@@ -1,166 +0,0 @@
"""
contains operations on the population: creating the next generation and population speciation.
"""
import jax
from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
                                         center_nodes, center_cons, species_keys, new_species_key_start,
                                         jit_config):
    """
    Build the next population and immediately assign it to species.

    Runs create_next_generation followed by speciate on the freshly created population.

    returns:
        the new pop_nodes and pop_cons, followed by the four results of speciate
        (idx2specie, center_nodes, center_cons, species_keys)
    """
    # step 1: reproduction
    next_nodes, next_cons = create_next_generation(
        rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config)

    # step 2: speciation of the new population
    speciation = speciate(
        next_nodes, next_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config)
    idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciation

    return next_nodes, next_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
    """
    Create the next population by crossover followed by mutation.

    Child i is the crossover of genomes winner[i] and loser[i]. Every child is then mutated with
    new_node_keys[i], except where elite_mask[i] is True: those keep the un-mutated crossover result.

    returns:
        the new pop_nodes and pop_cons
    """
    pop_size = pop_nodes.shape[0]

    # one independent random key per child for each of the two operators
    crossover_key, mutate_key = jax.random.split(rand_key, 2)
    crossover_rand_keys = jax.random.split(crossover_key, pop_size)
    mutate_rand_keys = jax.random.split(mutate_key, pop_size)

    # batch crossover between the winner and the loser parents
    batch_crossover = vmap(crossover)
    child_nodes, child_cons = batch_crossover(
        crossover_rand_keys, pop_nodes[winner], pop_cons[winner], pop_nodes[loser], pop_cons[loser])

    # batch mutation (jit_config is shared by all children)
    batch_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, None))
    mutated_nodes, mutated_cons = batch_mutate(mutate_rand_keys, child_nodes, child_cons, new_node_keys, jit_config)

    # elites are not mutated
    is_elite = elite_mask[:, None, None]
    next_nodes = jnp.where(is_elite, child_nodes, mutated_nodes)
    next_cons = jnp.where(is_elite, child_cons, mutated_cons)

    return next_nodes, next_cons
@jit
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
    """
    Assign every genome of the population to a species.

    Part 1 picks, for every existing species, the unassigned genome closest to its old center as
    the new center. Part 2 walks over the species slots: each center absorbs the still unassigned
    genomes closer than jit_config['compatibility_threshold'], and an empty slot is turned into a
    new species centered on the first unassigned genome.

    args:
        pop_nodes: (pop_size, N, 5)
        pop_cons: (pop_size, C, 4)
        center_nodes: (species_size, N, 5)
        center_cons: (species_size, C, 4)
        species_keys: (species_size,), key of each species slot, I_INT for an empty slot
        new_species_key_start: the key given to the first newly created species
        jit_config: Dict, reads 'compatibility_threshold' (also passed on to distance)
    returns:
        idx2specie, center_nodes, center_cons, species_keys
    """
    pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]

    # prepare distance functions
    o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None))  # one to population
    s2p_distance_func = vmap(
        o2p_distance_func, in_axes=(0, 0, None, None, None)  # center to population
    )

    # idx to specie key
    idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32)  # I_INT means not assigned to any species

    # part 1: find new centers
    # the distance between each species' center and each genome in population
    s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)

    def find_new_centers(i, carry):
        i2s, cn, cc = carry
        # find new center: the closest genome that is not assigned yet
        idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)

        # check whether species[i] exists or not
        # if it does not exist, set idx and i to I_INT, jax will not do array value assignment
        idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
        i = jnp.where(species_keys[i] != I_INT, i, I_INT)

        i2s = i2s.at[idx].set(species_keys[i])
        cn = cn.at[i].set(pop_nodes[idx])
        cc = cc.at[i].set(pop_cons[idx])
        return i2s, cn, cc

    idx2specie, center_nodes, center_cons = \
        jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))

    # part 2: assign members to each species
    def cond_func(carry):
        i, i2s, cn, cc, sk, ck = carry  # sk is short for species_keys, ck is short for current key
        not_all_assigned = ~jnp.all(i2s != I_INT)
        not_reach_species_upper_bounds = i < species_size
        return not_all_assigned & not_reach_species_upper_bounds

    def body_func(carry):
        i, i2s, cn, cc, sk, ck = carry
        # scn is short for spe_center_nodes, scc is short for spe_center_cons
        i2s, scn, scc, sk, ck = jax.lax.cond(
            sk[i] == I_INT,  # whether the current species is existing or not
            create_new_specie,  # if not existing, create a new specie
            update_exist_specie,  # if existing, update the specie
            (i, i2s, cn, cc, sk, ck)
        )

        return i + 1, i2s, scn, scc, sk, ck

    def create_new_specie(carry):
        i, i2s, cn, cc, sk, ck = carry

        # pick the first one who has not been assigned to any species
        idx = fetch_first(i2s == I_INT)

        # assign it to the new species
        sk = sk.at[i].set(ck)
        i2s = i2s.at[idx].set(ck)

        # update center genomes
        cn = cn.at[i].set(pop_nodes[idx])
        cc = cc.at[i].set(pop_cons[idx])

        i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
        return i2s, cn, cc, sk, ck + 1  # change to next new species key

    def update_exist_specie(carry):
        i, i2s, cn, cc, sk, ck = carry
        i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
        return i2s, cn, cc, sk, ck

    def speciate_by_threshold(carry):
        i, i2s, cn, cc, sk = carry

        # distance between the center genome of species i and the pop genomes
        o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
        close_enough_mask = o2p_distance < jit_config['compatibility_threshold']

        # when it is close enough, assign it to the species; genomes that are already assigned are not touched
        i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
        return i2s

    current_new_key = new_species_key_start

    # update idx2specie
    _, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
        cond_func,
        body_func,
        (0, idx2specie, center_nodes, center_cons, species_keys, current_new_key)
    )

    # if there are still some pop genomes not assigned to any species, add them to the last species
    # this seems to happen only when the number of species has reached the species upper bound
    idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)

    return idx2specie, center_nodes, center_cons, species_keys
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
    """
    Return the index of the smallest element of arr among the positions where mask is True.

    Positions where mask is False are treated as +inf.
    """
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))

View File

@@ -3,13 +3,13 @@ from typing import Union, Callable
import numpy as np
import jax
from jax import jit, vmap
from configs import Configer
from function_factory import FunctionFactory
from algorithms.neat import initialize_genomes, expand, expand_single
from algorithms.neat import initialize_genomes
from algorithms.neat.jit_species import update_species
from algorithms.neat.operations import create_next_generation_then_speciate
from algorithms.neat.population import create_next_generation, speciate, update_species
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
class Pipeline:
@@ -17,30 +17,27 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config, function_factory=None, seed=42):
def __init__(self, config, seed=42):
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
self.config = config # global config
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
self.symbols = {
'P': self.config['pop_size'],
'N': self.config['init_maximum_nodes'],
'C': self.config['init_maximum_connections'],
'S': self.config['init_maximum_species'],
}
self.P = config['pop_size']
self.N = config['init_maximum_nodes']
self.C = config['init_maximum_connections']
self.S = config['init_maximum_species']
self.generation = 0
self.best_genome = None
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
self.species_info = np.full((self.symbols['S'], 3), np.nan)
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
self.species_info = np.full((self.S, 3), np.nan)
self.species_info[0, :] = 0, -np.inf, 0
self.idx2species = np.zeros(self.symbols['P'], dtype=np.int32)
self.center_nodes = np.full((self.symbols['S'], self.symbols['N'], 5), np.nan)
self.center_cons = np.full((self.symbols['S'], self.symbols['C'], 4), np.nan)
self.idx2species = np.zeros(self.P, dtype=np.float32)
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
self.center_cons = np.full((self.S, self.C, 4), np.nan)
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
@@ -49,7 +46,10 @@ class Pipeline:
self.generation_timestamp = time.time()
self.evaluate_time = 0
print(self.config)
self.pop_unflatten_connections = jit(vmap(unflatten_connections))
self.pop_topological_sort = jit(vmap(topological_sort))
self.forward = create_forward_function(config)
def ask(self):
"""
@@ -71,52 +71,28 @@ class Pipeline:
e.g. numerical regression; Hyper-NEAT
"""
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
u_pop_cons = self.pop_unflatten_connections(self.pop_nodes, self.pop_cons)
pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons)
if self.config['forward_way'] == 'single':
forward_funcs = []
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
forward_funcs.append(func)
return forward_funcs
elif self.config['forward_way'] == 'pop':
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
return func
elif self.config['forward_way'] == 'common':
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
return func
else:
raise NotImplementedError
# only common mode is supported currently
assert self.config['forward_way'] == 'common'
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
def tell(self, fitnesses):
self.generation += 1
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
update_species(self.randkey, fitnesses, self.species_info, self.idx2species, self.center_nodes,
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
# node keys to be used in the mutation process
new_node_keys = np.arange(self.generation * self.config['pop_size'],
self.generation * self.config['pop_size'] + self.config['pop_size'])
self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
elite_mask, self.generation, self.jit_config)
# create the next generation and then speciate the population
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
create_next_generation_then_speciate(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
center_cons, species_keys, species_key_start, self.jit_config)
# carry data to cpu
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
# update randkey
self.randkey = jax.random.split(self.randkey)[0]
def get_func(self, name):
return self.function_factory.get(name, self.symbols)
self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate(
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
self.jit_config)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):

View File

@@ -0,0 +1,307 @@
"""
contains operations on the population: creating the next generation and population speciation.
"""
import jax
from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first, rank_elements
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
    """
    Update the species records and choose the parents of the next generation.

    args:
        randkey: random key
        fitness: Array[(pop_size,), float], the fitness of each individual
        species_info: Array[(species_size, 3), float], the information of each species
            [species_key, best_score, last_update]
        idx2species: Array[(pop_size,)], map the individual to the key of its species
        center_nodes: Array[(species_size, N, 5), float], the center nodes of each species
        center_cons: Array[(species_size, C, 4), float], the center connections of each species
        generation: int, current generation
        jit_config: Dict, the configuration of jit functions
    returns:
        species_info, center_nodes, center_cons: the species records sorted by species fitness
            (best first); removed species are NaN and come last
        winner, loser: Array[(pop_size,), int], indices of the fitter / less fit parent of each child
        elite_mask: Array[(pop_size,), bool], True for children that are copied elites
    """
    # update the fitness of each species
    species_fitness = update_species_fitness(species_info, idx2species, fitness)

    # stagnation species
    species_fitness, species_info, center_nodes, center_cons = \
        stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)

    # sort species by fitness, best first.
    # argsort places NaN last, so reversing it directly would move the stagnated species (NaN fitness)
    # to the front; map NaN to -inf first so that they end up at the end.
    sort_key = jnp.where(jnp.isnan(species_fitness), -jnp.inf, species_fitness)
    sort_indices = jnp.argsort(sort_key)[::-1]
    species_info = species_info[sort_indices]
    center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]

    # decide the number of members of each species by their fitness
    spawn_number = cal_spawn_numbers(species_info, jit_config)

    # crossover info
    winner, loser, elite_mask = \
        create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)

    return species_info, center_nodes, center_cons, winner, loser, elite_mask
def update_species_fitness(species_info, idx2species, fitness):
    """
    Derive the fitness of each species from the fitness of its members (max criterion).

    A species that has no member in idx2species gets -inf.
    """
    species_keys = species_info[:, 0]

    # membership[s, p] is True when individual p belongs to species s
    membership = idx2species[None, :] == species_keys[:, None]
    member_fitness = jnp.where(membership, fitness[None, :], -jnp.inf)
    return jnp.max(member_fitness, axis=1)
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
    """
    Update the progress record of each species and remove the species that have stagnated.

    A species has stagnated when its fitness does not exceed its recorded best score and more than
    jit_config['max_stagnation'] generations have passed since its last update. A species that beats
    its best score gets best_score and last_update refreshed.
    Species whose rank (as returned by rank_elements) is below jit_config['species_elitism'] are
    never removed.

    Removed species get NaN in species_fitness, species_info, center_nodes and center_cons.
    """
    species_key = species_info[:, 0]
    best_score = species_info[:, 1]
    last_update = species_info[:, 2]

    # stagnation is judged against the record from before this generation's refresh
    stagnated = (species_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])

    # refresh the record of the species that improved
    improved = species_fitness > best_score
    last_update = jnp.where(improved, generation, last_update)
    best_score = jnp.where(improved, species_fitness, best_score)
    species_info = jnp.stack([species_key, best_score, last_update], axis=1)

    # elite species are exempt from stagnation
    protected = rank_elements(species_fitness) < jit_config['species_elitism']
    stagnated = stagnated & ~protected

    # erase the stagnated species
    species_info = jnp.where(stagnated[:, None], jnp.nan, species_info)
    center_nodes = jnp.where(stagnated[:, None, None], jnp.nan, center_nodes)
    center_cons = jnp.where(stagnated[:, None, None], jnp.nan, center_cons)
    species_fitness = jnp.where(stagnated, jnp.nan, species_fitness)

    return species_fitness, species_info, center_nodes, center_cons
def cal_spawn_numbers(species_info, jit_config):
    """
    Decide how many members each species gets in the next generation (linear ranking selection).

    species_info is expected to be ordered best species first; a row whose key is NaN is an
    invalid species and gets no member. The i-th valid species receives a share proportional to
    (valid_species_num - i). Shares are floored and the remainder is given to the first species,
    so the result always sums to jit_config['pop_size'].

    e.g. 3 valid species, pop_size = 10 -> shares = [0.5, 0.33, 0.17], floored = [5, 3, 1],
    remainder 1 goes to the first species -> spawn_number = [6, 3, 1]
    """
    total = jit_config['pop_size']
    slot = jnp.arange(species_info.shape[0])

    is_valid = jnp.logical_not(jnp.isnan(species_info[:, 0]))
    n_valid = jnp.sum(is_valid)

    rank_weight = n_valid - slot  # e.g. [3, 2, 1]
    weight_sum = (n_valid + 1) * n_valid / 2  # e.g. 3 + 2 + 1 = 6
    rate = jnp.where(is_valid, rank_weight / weight_sum, 0)  # invalid species get nothing

    counts = jnp.floor(rate * total).astype(jnp.int32)

    # the first species absorbs whatever flooring left over
    counts = counts.at[0].add(total - jnp.sum(counts))
    return counts
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
    """
    Choose the two parents of every individual of the next generation.

    For each species, the members are sorted by fitness (best first) and parents are drawn
    uniformly, with replacement, from the best floor(survival_threshold * members_num) members
    (at least one). The first jit_config['genome_elitism'] slots of each species are reserved for
    its best members, which are paired with themselves and flagged in elite_mask.
    Slot idx of the new population is served by the species that spawn_number assigns it to.

    args:
        randkey: random key
        species_info: Array[(species_size, 3), float], column 0 is the species key
        idx2species: Array[(pop_size,)], the species key of each individual
        spawn_number: Array[(species_size,), int], number of children of each species
        fitness: Array[(pop_size,), float], the fitness of each individual
        jit_config: Dict, reads 'genome_elitism' and 'survival_threshold'
    returns:
        winner, loser: Array[(pop_size,), int], for each child the index of the parent with the
            higher (or equal) fitness and of the other parent
        elite_mask: Array[(pop_size,), bool], True where the child is an elite
    """
    species_size = species_info.shape[0]
    pop_size = fitness.shape[0]
    s_idx = jnp.arange(species_size)
    p_idx = jnp.arange(pop_size)

    def sample_parents(key, idx):
        members = idx2species == species_info[idx, 0]
        members_num = jnp.sum(members)

        # non-members get -inf so that they come last once the ascending argsort is reversed
        members_fitness = jnp.where(members, fitness, -jnp.inf)
        sorted_member_indices = jnp.argsort(members_fitness)[::-1]

        elite_size = jit_config['genome_elitism']
        survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
        # keep at least one survivor, otherwise the probabilities below become 0 / 0 = NaN
        survive_size = jnp.maximum(survive_size, 1)

        select_pro = (p_idx < survive_size) / survive_size
        fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)

        # elite
        fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
        ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
        elite = jnp.where(p_idx < elite_size, True, False)
        return fa, ma, elite

    fas, mas, elites = vmap(sample_parents)(jax.random.split(randkey, species_size), s_idx)

    spawn_number_cum = jnp.cumsum(spawn_number)

    def pick_for_slot(idx):
        # the first species whose cumulative spawn number exceeds idx owns this slot
        loc = jnp.argmax(idx < spawn_number_cum)

        # elite genomes are at the beginning of the species
        idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
        return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

    part1, part2, elite_mask = vmap(pick_for_slot)(p_idx)

    is_part1_win = fitness[part1] >= fitness[part2]
    winner = jnp.where(is_part1_win, part1, part2)
    loser = jnp.where(is_part1_win, part2, part1)

    return winner, loser, elite_mask
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
    """
    Create the next population by crossover followed by mutation.

    Child i is the crossover of genomes winner[i] and loser[i]. Every child is then mutated, using
    generation * pop_size + i as its new node key, except where elite_mask[i] is True: those keep
    the un-mutated crossover result.

    returns:
        the new pop_nodes and pop_cons
    """
    pop_size = pop_nodes.shape[0]

    # node keys handed to the mutation, one per child, distinct across generations
    new_node_keys = jnp.arange(pop_size) + generation * pop_size

    # one independent random key per child for each of the two operators
    crossover_key, mutate_key = jax.random.split(rand_key, 2)
    crossover_rand_keys = jax.random.split(crossover_key, pop_size)
    mutate_rand_keys = jax.random.split(mutate_key, pop_size)

    # batch crossover between the winner and the loser parents
    batch_crossover = vmap(crossover)
    child_nodes, child_cons = batch_crossover(
        crossover_rand_keys, pop_nodes[winner], pop_cons[winner], pop_nodes[loser], pop_cons[loser])

    # batch mutation (jit_config is shared by all children)
    batch_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, None))
    mutated_nodes, mutated_cons = batch_mutate(mutate_rand_keys, child_nodes, child_cons, new_node_keys, jit_config)

    # elites are not mutated
    is_elite = elite_mask[:, None, None]
    next_nodes = jnp.where(is_elite, child_nodes, mutated_nodes)
    next_cons = jnp.where(is_elite, child_cons, mutated_cons)

    return next_nodes, next_cons
@jit
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
    """
    Assign every genome of the population to a species.

    Part 1 picks, for every existing species, the unassigned genome closest to its old center as
    the new center. Part 2 walks over the species slots: each center absorbs the still unassigned
    genomes closer than jit_config['compatibility_threshold'], and an empty slot is turned into a
    new species centered on the first unassigned genome.

    args:
        pop_nodes: (pop_size, N, 5)
        pop_cons: (pop_size, C, 4)
        species_info: (species_size, 3), [species_key, best_score, last_update]; NaN rows are empty slots
        center_nodes: (species_size, N, 5)
        center_cons: (species_size, C, 4)
        generation: int, current generation, recorded as last_update of newly created species
        jit_config: Dict, reads 'compatibility_threshold' (also passed on to distance)
    returns:
        idx2specie: (pop_size,), float, the species key of each genome
        center_nodes, center_cons: the updated center genomes
        species_info: the species records including the newly created species
    """
    pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]

    # prepare distance functions
    o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None))  # one to population
    s2p_distance_func = vmap(
        o2p_distance_func, in_axes=(0, 0, None, None, None)  # center to population
    )

    # idx to specie key
    idx2specie = jnp.full((pop_size,), jnp.nan)  # NaN means not assigned to any species

    # part 1: find new centers
    # the distance between each species' center and each genome in population
    s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)

    def find_new_centers(i, carry):
        i2s, cn, cc = carry
        # find new center: the closest genome that is not assigned yet
        idx = argmin_with_mask(s2p_distance[i], mask=jnp.isnan(i2s))

        # check whether species[i] exists or not
        # if it does not exist, set idx and i to I_INT (an out-of-bounds index),
        # so that jax drops the array value assignments below
        idx = jnp.where(~jnp.isnan(species_info[i, 0]), idx, I_INT)
        i = jnp.where(~jnp.isnan(species_info[i, 0]), i, I_INT)

        i2s = i2s.at[idx].set(species_info[i, 0])
        cn = cn.at[i].set(pop_nodes[idx])
        cc = cc.at[i].set(pop_cons[idx])
        return i2s, cn, cc

    idx2specie, center_nodes, center_cons = \
        jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))

    # part 2: assign members to each species
    def cond_func(carry):
        i, i2s, cn, cc, si, ck = carry  # si is short for species_info, ck is short for current key
        not_all_assigned = jnp.any(jnp.isnan(i2s))
        not_reach_species_upper_bounds = i < species_size
        return not_all_assigned & not_reach_species_upper_bounds

    def body_func(carry):
        i, i2s, cn, cc, si, ck = carry
        # scn is short for spe_center_nodes, scc is short for spe_center_cons
        i2s, scn, scc, si, ck = jax.lax.cond(
            jnp.isnan(si[i, 0]),  # whether the current species is existing or not
            create_new_specie,  # if not existing, create a new specie
            update_exist_specie,  # if existing, update the specie
            (i, i2s, cn, cc, si, ck)
        )

        return i + 1, i2s, scn, scc, si, ck

    def create_new_specie(carry):
        i, i2s, cn, cc, si, ck = carry

        # pick the first one who has not been assigned to any species
        idx = fetch_first(jnp.isnan(i2s))

        # assign it to the new species
        si = si.at[i].set(jnp.array([ck, -jnp.inf, generation]))  # [key, best score, last update generation]
        i2s = i2s.at[idx].set(ck)

        # update center genomes
        cn = cn.at[i].set(pop_nodes[idx])
        cc = cc.at[i].set(pop_cons[idx])

        i2s = speciate_by_threshold((i, i2s, cn, cc, si))
        return i2s, cn, cc, si, ck + 1  # change to next new species key

    def update_exist_specie(carry):
        i, i2s, cn, cc, si, ck = carry
        i2s = speciate_by_threshold((i, i2s, cn, cc, si))
        return i2s, cn, cc, si, ck

    def speciate_by_threshold(carry):
        i, i2s, cn, cc, si = carry

        # distance between the center genome of species i and the pop genomes
        o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
        close_enough_mask = o2p_distance < jit_config['compatibility_threshold']

        # when it is close enough, assign it to the species; genomes that are already assigned are not touched
        i2s = jnp.where(close_enough_mask & jnp.isnan(i2s), si[i, 0], i2s)
        return i2s

    # new species keys continue after the largest existing key
    species_keys = species_info[:, 0]
    current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1

    # update idx2specie
    _, idx2specie, center_nodes, center_cons, species_info, _ = jax.lax.while_loop(
        cond_func,
        body_func,
        (0, idx2specie, center_nodes, center_cons, species_info, current_new_key)
    )

    # if there are still some pop genomes not assigned to any species, add them to the last species
    # this seems to happen only when the number of species has reached the species upper bound
    # (unassigned genomes are marked with NaN, so they must be detected with isnan, not `== I_INT`)
    idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie)

    return idx2specie, center_nodes, center_cons, species_info
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
    """
    Return the index of the smallest entry of ``arr`` among the positions selected by ``mask``.

    :param arr: values to search.
    :param mask: boolean array; positions where it is False are replaced by ``+inf``
        before the search, so they can only win when nothing finite is selected.
    :return: result of ``jnp.argmin`` over the masked values.
    """
    # de-selected positions become +inf so that argmin ignores them
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))

View File

@@ -1,283 +0,0 @@
"""
Species Controller in NEAT.
The code is modified from neat-python.
See
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
"""
from typing import List, Tuple, Dict
import numpy as np
from numpy.typing import NDArray
from .genome.utils import I_INT
class Species:
    """
    Plain record describing one species.

    Attributes set here:
        key: identifier of the species.
        created: generation in which the species was created.
        last_improved: generation recorded as the last improvement (starts at ``created``).
        representative: ``(center_nodes, center_connections)`` pair, ``(None, None)`` until ``update``.
        members: indices into pop_nodes / pop_connections, ``None`` until ``update``.
        fitness, member_fitnesses, adjusted_fitness: filled in by the controller, initially ``None``.
        fitness_history: list of past species fitness values.
    """

    def __init__(self, key, generation):
        self.key = key
        # a new species starts with "last improved" equal to its creation generation
        self.created = self.last_improved = generation
        self.representative: Tuple[NDArray, NDArray] = (None, None)  # (center_nodes, center_connections)
        self.members: NDArray = None  # idx in pop_nodes, pop_connections
        self.fitness = self.member_fitnesses = self.adjusted_fitness = None
        self.fitness_history: List[float] = []

    def update(self, representative, members):
        """Replace the representative genome and the member indices of this species."""
        self.representative, self.members = representative, members

    def get_fitnesses(self, fitnesses):
        """Select from ``fitnesses`` the entries that belong to the members of this species."""
        member_idx = self.members
        return fitnesses[member_idx]
class SpeciesController:
    """
    A class to control the species.

    It keeps a ``species_id -> Species`` mapping and, once per generation,
      * ``ask``  : scores the species, drops stagnated ones and decides which genomes are
                   kept as elites / crossed over for the next generation;
      * ``tell`` : stores the membership and representatives produced by the speciation step.
    """

    def __init__(self, config):
        """
        :param config: mapping that provides 'species_elitism', 'pop_size', 'max_stagnation',
            'min_species_size', 'genome_elitism' and 'survival_threshold'.
        """
        self.config = config

        self.species_elitism = self.config['species_elitism']
        self.pop_size = self.config['pop_size']
        self.max_stagnation = self.config['max_stagnation']
        self.min_species_size = self.config['min_species_size']
        self.genome_elitism = self.config['genome_elitism']
        self.survival_threshold = self.config['survival_threshold']

        self.species: Dict[int, Species] = {}  # species_id -> species

    def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
        """
        Speciate for the first generation: every genome is put into a single species
        (key 0) whose representative is the first genome of the population.

        :param pop_nodes: node arrays of the population (first axis is the genome index).
        :param pop_connections: connection arrays of the population.
        :return: None
        """
        pop_size = pop_nodes.shape[0]
        species_id = 0  # the first species
        s = Species(species_id, 0)
        members = np.arange(pop_size)
        s.update((pop_nodes[0], pop_connections[0]), members)
        self.species[species_id] = s

    def __update_species_fitnesses(self, fitnesses):
        """
        Update the fitness of each species from the fitnesses of its members.

        :param fitnesses: per-genome fitness array, indexed like the population.
        :return: None
        """
        for s in self.species.values():
            s.member_fitnesses = s.get_fitnesses(fitnesses)
            # use the max score to represent the fitness of the species
            s.fitness = np.max(s.member_fitnesses)
            s.fitness_history.append(s.fitness)
            s.adjusted_fitness = None

    def __stagnation(self, generation):
        """
        Decide for every species whether it is stagnated.

        Must be called after ``__update_species_fitnesses``, which has already appended the
        fitness of the current generation to ``fitness_history``.

        :param generation: index of the current generation.
        :return: list of ``(species_id, species, is_stagnant)`` sorted by descending fitness.
        """
        species_data = []
        for sid, s in self.species.items():
            # The last history entry IS the current fitness, so it must be left out of the
            # comparison; otherwise ``s.fitness > prev_fitness`` can never hold and no
            # species would ever be recorded as improved.
            earlier_fitnesses = s.fitness_history[:-1]
            prev_fitness = max(earlier_fitnesses) if earlier_fitnesses else float('-inf')

            if s.fitness > prev_fitness:
                s.last_improved = generation

            species_data.append((sid, s))

        # Sort in descending fitness order.
        species_data.sort(key=lambda x: x[1].fitness, reverse=True)

        result = []
        for idx, (sid, s) in enumerate(species_data):
            if idx < self.species_elitism:  # elitism species never stagnate!
                is_stagnant = False
            else:
                stagnant_time = generation - s.last_improved
                is_stagnant = stagnant_time > self.max_stagnation
            result.append((sid, s, is_stagnant))
        return result

    def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
        """
        Choose the parents of the next generation.

        :param fitnesses: per-genome fitness array, indexed like the population.
        :param generation: index of the current generation.
        :return: ``(winner_part, loser_part, elite_mask)``. For every slot of the next
            generation the two entries are parent indices into the population, the winner
            being the one with the higher (or equal) fitness. Slots whose ``elite_mask`` is
            True hold the same index twice (an elite genome copied as is).
        """
        # Filter out stagnated species, collect the set of non-stagnated
        # species members, and compute their average adjusted fitness.
        # The average adjusted fitness scheme (normalized to the interval
        # [0, 1]) allows the use of negative fitness values without
        # interfering with the shared fitness scheme.
        min_fitness = np.inf
        max_fitness = -np.inf

        remaining_species = []
        for _, stag_s, stagnant in self.__stagnation(generation):
            if not stagnant:
                min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
                max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
                remaining_species.append(stag_s)

        # No species left.
        assert remaining_species

        # TODO: Too complex!
        # Compute each species' member size in the next generation.

        # Do not allow the fitness range to be zero, as we divide by it below.
        # TODO: The ``1.0`` below is rather arbitrary, and should be configurable.
        fitness_range = max(1.0, max_fitness - min_fitness)
        for afs in remaining_species:
            # Compute adjusted fitness, scaled into [0, 1].
            afs.adjusted_fitness = (afs.fitness - min_fitness) / fitness_range

        adjusted_fitnesses = [s.adjusted_fitness for s in remaining_species]
        previous_sizes = [len(s.members) for s in remaining_species]
        min_species_size = max(self.min_species_size, self.genome_elitism)
        spawn_amounts = compute_spawn(adjusted_fitnesses, previous_sizes, self.pop_size, min_species_size)
        assert sum(spawn_amounts) == self.pop_size

        # generate new population and speciate
        self.species = {}
        # int -> idx in the pop_nodes, pop_connections of elitism
        # (int, int) -> the father and mother idx to be crossover
        part1, part2, elite_mask = [], [], []

        for spawn, s in zip(spawn_amounts, remaining_species):
            assert spawn >= self.genome_elitism

            # retain remain species to next generation
            old_members, member_fitnesses = s.members, s.member_fitnesses
            s.members = []
            self.species[s.key] = s

            # add elitism genomes to next generation
            sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, member_fitnesses)
            if self.genome_elitism > 0:
                for m in sorted_members[:self.genome_elitism]:
                    part1.append(m)
                    part2.append(m)
                    elite_mask.append(True)
                    spawn -= 1

            if spawn <= 0:
                continue

            # add genome to be crossover to next generation
            repro_cutoff = int(np.ceil(self.survival_threshold * len(sorted_members)))
            repro_cutoff = max(repro_cutoff, 2)
            # only use good genomes to crossover
            sorted_members = sorted_members[:repro_cutoff]

            # TODO: Genome with higher fitness should be more likely to be selected?
            list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
            part1.extend(sorted_members[list_idx1])
            part2.extend(sorted_members[list_idx2])
            elite_mask.extend([False] * spawn)

        part1_fitness, part2_fitness = fitnesses[part1], fitnesses[part2]
        is_part1_win = part1_fitness >= part2_fitness
        winner_part = np.where(is_part1_win, part1, part2)
        loser_part = np.where(is_part1_win, part2, part1)

        return winner_part, loser_part, np.array(elite_mask)

    def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
        """
        Store the result of the speciation step.

        :param idx2specie: for every genome, the key of the species it was assigned to.
        :param center_nodes: representative node arrays, one row per species slot.
        :param center_cons: representative connection arrays, one row per species slot.
        :param species_keys: key of the species in every slot; slots equal to ``I_INT`` are skipped.
        :param generation: index of the current generation (creation time of new species).
        """
        for idx, key in enumerate(species_keys):
            if key == I_INT:
                continue

            members = np.where(idx2specie == key)[0]
            assert len(members) > 0

            if key not in self.species:
                # the new specie created in this generation
                self.species[key] = Species(key, generation)

            self.species[key].update((center_nodes[idx], center_cons[idx]), members)

    def ask(self, fitnesses, generation, symbols):
        """
        Score the current species and prepare the inputs of the next-generation step.

        :param fitnesses: per-genome fitness array, indexed like the population.
        :param generation: index of the current generation.
        :param symbols: mapping with the current array sizes; 'S', 'N' and 'C' are read here.
        :return: ``(winner, loser, elite_mask, center_nodes, center_cons, species_keys,
            next_new_specie_key)``; unused species slots hold NaN centers and the key ``I_INT``.
        """
        self.__update_species_fitnesses(fitnesses)

        winner, loser, elite_mask = self.__reproduce(fitnesses, generation)

        center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
        center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
        species_keys = np.full((symbols['S'], ), I_INT)

        for idx, (key, specie) in enumerate(self.species.items()):
            center_nodes[idx], center_cons[idx] = specie.representative
            species_keys[idx] = key

        next_new_specie_key = max(self.species.keys()) + 1

        return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
    """
    Code from neat-python, modified so that the amounts always sum to ``pop_size``.
    Compute the proper number of offspring per species (proportional to fitness).

    :param adjusted_fitness: adjusted fitness of every species.
    :param previous_sizes: current number of members of every species (same order).
    :param pop_size: total number of genomes the next generation must contain.
    :param min_species_size: lower bound for the size of every species.
    :return: list with the number of genomes every species gets in the next generation.

    After rounding, the difference to ``pop_size`` is corrected as follows: a surplus is
    given to the first species; a deficit is taken from the first species only down to
    ``min_species_size`` and then from the following species, so no species drops below
    the minimum while there is slack anywhere. Only when the request is infeasible
    (``pop_size`` smaller than the sum of the minima) is the rest charged to the first species.
    """
    af_sum = sum(adjusted_fitness)

    spawn_amounts = []
    for af, ps in zip(adjusted_fitness, previous_sizes):
        if af_sum > 0:
            target = max(min_species_size, af / af_sum * pop_size)
        else:
            target = min_species_size

        # move half of the way from the previous size towards the target size
        d = (target - ps) * 0.5
        c = int(round(d))
        spawn = ps
        if abs(c) > 0:
            spawn += c
        elif d > 0:
            spawn += 1
        elif d < 0:
            spawn -= 1

        spawn_amounts.append(spawn)

    # Normalize the spawn amounts so that the next generation is roughly
    # the population size requested by the user.
    total_spawn = sum(spawn_amounts)
    norm = pop_size / total_spawn

    spawn_amounts = [max(min_species_size, int(round(n * norm))) for n in spawn_amounts]

    # for batch parallelization, pop size must be a fixed value.
    difference = pop_size - sum(spawn_amounts)
    if difference >= 0:
        spawn_amounts[0] += difference
    else:
        deficit = -difference
        for i, amount in enumerate(spawn_amounts):
            take = min(deficit, amount - min_species_size)
            if take > 0:
                spawn_amounts[i] -= take
                deficit -= take
            if deficit == 0:
                break
        # infeasible request: keep the total fixed, as the original correction did
        spawn_amounts[0] -= deficit

    assert sum(spawn_amounts) == pop_size, "Population size is not stable."

    return spawn_amounts
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) -> Tuple[NDArray, NDArray]:
    """
    Order ``members`` and ``fitnesses`` together by descending fitness.

    :param members: array of member indices.
    :param fitnesses: fitness of every member, aligned with ``members``.
    :return: ``(members, fitnesses)`` reordered so that the highest fitness comes first.
    """
    # ascending argsort, reversed to obtain the descending order
    descending = np.flip(np.argsort(fitnesses), axis=0)
    return members[descending], fitnesses[descending]

View File

@@ -1,7 +1,7 @@
[basic]
num_inputs = 2
num_outputs = 1
init_maximum_nodes = 200
init_maximum_nodes = 50
init_maximum_connections = 200
init_maximum_species = 10
expand_coe = 1.5
@@ -11,9 +11,9 @@ batch_size = 4
[population]
fitness_threshold = 100000
generation_limit = 100
generation_limit = 1000
fitness_criterion = "max"
pop_size = 150
pop_size = 1500
[genome]
compatibility_disjoint = 1.0

View File

@@ -1,26 +0,0 @@
import jax
from jax import numpy as jnp
from evox import algorithms, problems, pipelines
from evox.monitors import StdSOMonitor
# Demo: minimise the 2-D Ackley function with PSO and report the best fitness recorded.
monitor = StdSOMonitor()

# search space: 2 dimensions, each bounded to [-32, 32]
lower_bound = jnp.full(shape=(2,), fill_value=-32)
upper_bound = jnp.full(shape=(2,), fill_value=32)
pso = algorithms.PSO(lb=lower_bound, ub=upper_bound, pop_size=100)

ackley = problems.classic.Ackley()
# every fitness evaluation is passed through the monitor
pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit)

key = jax.random.PRNGKey(42)
state = pipeline.init(key)

# run the pipeline for 100 steps
for _ in range(100):
    state = pipeline.step(state)

print(monitor.get_min_fitness())

View File

@@ -1,28 +0,0 @@
import numpy as np
from configs import Configer
from jit_pipeline import Pipeline
# XOR truth table: four input pairs and the expected output of each
xor_inputs = np.array(
    [[0, 0], [0, 1], [1, 0], [1, 1]],
    dtype=np.float32,
)
xor_outputs = np.array(
    [[0], [1], [1], [0]],
    dtype=np.float32,
)


def evaluate(forward_func):
    """
    Score a whole population on the XOR task.

    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return: NumPy array with one fitness per genome: 4 minus the summed squared error
        over the four XOR cases.
    """
    predictions = forward_func(xor_inputs)
    squared_error = (predictions - xor_outputs) ** 2
    # sum over the batch and output axes, leaving one value per genome
    return np.array(4 - np.sum(squared_error, axis=(1, 2)))
def main():
    """Load the XOR configuration, run the pipeline with ``evaluate`` and print the result."""
    pipeline = Pipeline(Configer.load_config("xor.ini"), seed=6)
    best_nodes, best_cons = pipeline.auto_run(evaluate)
    print(best_nodes, best_cons)


if __name__ == '__main__':
    main()

View File

@@ -1,7 +1,8 @@
import numpy as np
from configs import Configer
from pipeline import Pipeline
from algorithms.neat.pipeline import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -21,7 +22,8 @@ def main():
config = Configer.load_config("xor.ini")
pipeline = Pipeline(config, seed=6)
nodes, cons = pipeline.auto_run(evaluate)
print(nodes, cons)
print(nodes)
print(cons)
if __name__ == '__main__':

View File

@@ -1,132 +0,0 @@
import numpy as np
from jax import jit, vmap
from algorithms.neat import create_forward, topological_sort, \
unflatten_connections, create_next_generation_then_speciate
def hash_symbols(symbols):
    """
    Build a hashable cache key from the size symbols.

    :param symbols: mapping that contains the keys 'P', 'N', 'C' and 'S'.
    :return: tuple ``(P, N, C, S)`` of the corresponding values.
    """
    return tuple(symbols[name] for name in ('P', 'N', 'C', 'S'))
class FunctionFactory:
    """
    Creates and compiles functions used in the NEAT pipeline.

    Every entry of ``function_info`` holds a function prototype ('func') and the description
    of its operands ('lowers'). A shape may contain the symbols 'P', 'N', 'C', 'S'; they are
    replaced by the current sizes when the function is compiled, and one compiled version
    is cached per ``(name, (P, N, C, S))``.
    """

    def __init__(self, config, jit_config):
        """
        :param config: global config; 'batch_size' and 'num_inputs' are read here and the
            whole mapping is passed to ``create_forward``.
        :param jit_config: object handed to functions whose operand list contains "jit_config".
        """
        self.config = config
        self.jit_config = jit_config
        self.func_dict = {}  # (name, hashed symbols) -> compiled function

        def operand(shape, dtype):
            # description of one array operand used for lowering
            return {'shape': shape, 'type': dtype}

        batch_size, num_inputs = config['batch_size'], config['num_inputs']

        # (inputs_nums, ) -> (outputs_nums, )
        forward = create_forward(config)  # input size (inputs_nums, )
        # (batch_size, inputs_nums) -> (batch_size, outputs_nums)
        batch_forward = vmap(forward, in_axes=(0, None, None, None))
        # (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
        pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
        # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
        common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))

        self.function_info = {
            "pop_unflatten_connections": {
                'func': vmap(unflatten_connections),
                'lowers': [
                    operand(('P', 'N', 5), np.float32),
                    operand(('P', 'C', 4), np.float32),
                ],
            },
            "pop_topological_sort": {
                'func': vmap(topological_sort),
                'lowers': [
                    operand(('P', 'N', 5), np.float32),
                    operand(('P', 2, 'N', 'N'), np.float32),
                ],
            },
            "batch_forward": {
                'func': batch_forward,
                'lowers': [
                    operand((batch_size, num_inputs), np.float32),
                    operand(('N',), np.int32),
                    operand(('N', 5), np.float32),
                    operand((2, 'N', 'N'), np.float32),
                ],
            },
            "pop_batch_forward": {
                'func': pop_batch_forward,
                'lowers': [
                    operand(('P', batch_size, num_inputs), np.float32),
                    operand(('P', 'N'), np.int32),
                    operand(('P', 'N', 5), np.float32),
                    operand(('P', 2, 'N', 'N'), np.float32),
                ],
            },
            'common_forward': {
                'func': common_forward,
                'lowers': [
                    operand((batch_size, num_inputs), np.float32),
                    operand(('P', 'N'), np.int32),
                    operand(('P', 'N', 5), np.float32),
                    operand(('P', 2, 'N', 'N'), np.float32),
                ],
            },
            'create_next_generation_then_speciate': {
                'func': create_next_generation_then_speciate,
                'lowers': [
                    operand((2,), np.uint32),  # rand_key
                    operand(('P', 'N', 5), np.float32),  # pop_nodes
                    operand(('P', 'C', 4), np.float32),  # pop_cons
                    operand(('P',), np.int32),  # winner
                    operand(('P',), np.int32),  # loser
                    operand(('P',), bool),  # elite_mask
                    operand(('P',), np.int32),  # new_node_keys
                    operand(('S', 'N', 5), np.float32),  # center_nodes
                    operand(('S', 'C', 4), np.float32),  # center_cons
                    operand(('S',), np.int32),  # species_keys
                    operand((), np.int32),  # new_species_key_start
                    "jit_config",
                ],
            },
        }

    def get(self, name, symbols):
        """
        Return the compiled function ``name`` for the given sizes, compiling it on first use.

        :param name: key of ``function_info``.
        :param symbols: mapping with the current values of 'P', 'N', 'C' and 'S'.
        """
        cache_key = (name, hash_symbols(symbols))
        if cache_key not in self.func_dict:
            self.compile(name, symbols)
        return self.func_dict[cache_key]

    def compile(self, name, symbols):
        """
        Lower and compile the function ``name`` for the given sizes and cache the result.

        :param name: key of ``function_info``.
        :param symbols: mapping with the current values of 'P', 'N', 'C' and 'S'.
        :raises ValueError: when an operand description is neither a dict nor "jit_config".
        """
        info = self.function_info[name]

        # build one placeholder operand per description
        placeholders = []
        for spec in info['lowers']:
            if isinstance(spec, dict):
                dims = []
                for dim in spec['shape']:
                    if dim in symbols:
                        dim = symbols[dim]  # replace the symbol with its current size
                    assert isinstance(dim, int)
                    dims.append(dim)
                placeholders.append(np.zeros(dims, dtype=spec['type']))
            elif spec == "jit_config":
                placeholders.append(self.jit_config)
            else:
                raise ValueError("Invalid lower operand")

        # compile ahead of time and save for reuse
        self.func_dict[name, hash_symbols(symbols)] = jit(info['func']).lower(*placeholders).compile()

View File

@@ -1,189 +0,0 @@
import time
from typing import Union, Callable
import numpy as np
import jax
from configs import Configer
from function_factory import FunctionFactory
from algorithms.neat import initialize_genomes, expand, expand_single, SpeciesController
class Pipeline:
    """
    Neat algorithm pipeline.

    Drives the ask / evaluate / tell loop: ``ask`` hands out forward function(s) for the
    current population, the caller turns them into fitnesses, and ``tell`` creates and
    speciates the next generation.
    """

    def __init__(self, config, function_factory=None, seed=42):
        """
        :param config: global config.
        :param function_factory: optional factory of compiled functions; a ``FunctionFactory``
            is created when omitted.
        :param seed: seed for both the jax random key and ``np.random``.
        """
        self.randkey = jax.random.PRNGKey(seed)
        np.random.seed(seed)

        self.config = config  # global config
        self.jit_config = Configer.create_jit_config(config)  # config used in jit-able functions
        self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)

        # current array sizes; 'N' and 'C' grow in ``expand``
        self.symbols = {
            'P': self.config['pop_size'],
            'N': self.config['init_maximum_nodes'],
            'C': self.config['init_maximum_connections'],
            'S': self.config['init_maximum_species'],
        }

        self.generation = 0

        self.species_controller = SpeciesController(self.config)
        self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
        self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)

        self.best_fitness = float('-inf')
        self.best_genome = None
        self.generation_timestamp = time.time()
        self.evaluate_time = 0

        print(self.config)

    def _bind_single_forward(self, seq, nodes, cons):
        """Return a forward function bound to one genome (used by the 'single' forward way)."""
        def forward(x):
            return self.get_func('batch_forward')(x, seq, nodes, cons)
        return forward

    def ask(self):
        """
        Creates a function that receives a genome and returns a forward function.
        There are 3 types of config['forward_way']: {'single', 'pop', 'common'}

        single:
            Create pop_size number of forward functions.
            Each function receive (batch_size, input_size) and returns (batch_size, output_size)
            e.g. RL task

        pop:
            Create a single forward function, which use only once calculation for the population.
            The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)

        common:
            Special case of pop. The population has the same inputs.
            The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size)
            e.g. numerical regression; Hyper-NEAT

        :raises NotImplementedError: for any other value of config['forward_way'].
        """
        u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
        pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)

        forward_way = self.config['forward_way']
        if forward_way == 'single':
            # Every function must be bound to its own genome. A lambda created directly in
            # the loop would look up seq / nodes / cons when it is called, i.e. after the
            # loop has finished, and all functions would evaluate the last genome.
            return [
                self._bind_single_forward(seq, nodes, cons)
                for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons)
            ]

        elif forward_way == 'pop':
            func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
            return func

        elif forward_way == 'common':
            func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
            return func

        else:
            raise NotImplementedError

    def tell(self, fitnesses):
        """
        Create the next generation from the fitnesses of the current one and speciate it.

        :param fitnesses: per-genome fitness array, indexed like the population.
        """
        self.generation += 1

        winner, loser, elite_mask, center_nodes, center_cons, species_keys, species_key_start = \
            self.species_controller.ask(fitnesses, self.generation, self.symbols)

        # node keys to be used in the mutation process
        pop_size = self.config['pop_size']
        new_node_keys = np.arange(self.generation * pop_size, self.generation * pop_size + pop_size)

        # create the next generation and then speciate the population
        self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
            self.get_func('create_next_generation_then_speciate') \
                (self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
                 center_cons, species_keys, species_key_start, self.jit_config)

        # carry data to cpu
        self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
            jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])

        self.species_controller.tell(idx2specie, center_nodes, center_cons, species_keys, self.generation)

        # expand the population if needed
        self.expand()

        # update randkey
        self.randkey = jax.random.split(self.randkey)[0]

    def expand(self):
        """
        Expand the population if needed.
        when the maximum node number >= N or the maximum connection number of >= C
        the population will expand
        """
        # analysis nodes: a slot is in use when its key (column 0) is not NaN
        pop_node_keys = self.pop_nodes[:, :, 0]
        pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
        max_node_size = np.max(pop_node_sizes)

        # analysis connections
        pop_con_keys = self.pop_cons[:, :, 0]
        pop_con_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
        max_con_size = np.max(pop_con_sizes)

        # expand if needed
        if max_node_size >= self.symbols['N'] or max_con_size >= self.symbols['C']:
            if max_node_size > self.symbols['N'] * self.config['pre_expand_threshold']:
                self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
                print(f"pre node expand to {self.symbols['N']}!")

            if max_con_size > self.symbols['C'] * self.config['pre_expand_threshold']:
                self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
                print(f"pre connection expand to {self.symbols['C']}!")

            self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
            # don't forget to expand representation genome in species
            for s in self.species_controller.species.values():
                s.representative = expand_single(*s.representative, self.symbols['N'], self.symbols['C'])

    def get_func(self, name):
        """Return the compiled function ``name`` for the current sizes."""
        return self.function_factory.get(name, self.symbols)

    def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
        """
        Run the ask / evaluate / tell loop until the fitness threshold or the generation
        limit of the config is reached.

        :param fitness_func: receives the result of ``ask`` and returns the fitnesses.
        :param analysis: "default" to call ``default_analysis``, a callable receiving the
            fitnesses, or None to skip the analysis.
        :return: ``self.best_genome``. Note that it is only updated by ``default_analysis``.
        """
        for _ in range(self.config['generation_limit']):
            forward_func = self.ask()

            tic = time.time()
            fitnesses = fitness_func(forward_func)
            self.evaluate_time += time.time() - tic

            assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"

            if analysis is not None:
                if analysis == "default":
                    self.default_analysis(fitnesses)
                else:
                    assert callable(analysis), f"Callable is needed here😅😅😅 A {analysis}?"
                    analysis(fitnesses)

            if max(fitnesses) >= self.config['fitness_threshold']:
                print("Fitness limit reached!")
                return self.best_genome

            self.tell(fitnesses)
        print("Generation limit reached!")
        return self.best_genome

    def default_analysis(self, fitnesses):
        """
        Print statistics of the current generation and remember the best genome seen so far.

        :param fitnesses: per-genome fitness array, indexed like the population.
        """
        max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
        species_sizes = [len(s.members) for s in self.species_controller.species.values()]

        new_timestamp = time.time()
        cost_time = new_timestamp - self.generation_timestamp
        self.generation_timestamp = new_timestamp

        max_idx = np.argmax(fitnesses)
        if fitnesses[max_idx] > self.best_fitness:
            self.best_fitness = fitnesses[max_idx]
            self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])

        print(f"Generation: {self.generation}",
              f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")