368 lines
15 KiB
Python
368 lines
15 KiB
Python
from typing import Type
|
|
|
|
import jax
|
|
from jax import numpy as jnp, vmap
|
|
|
|
from .utils import rank_elements, fetch_first
|
|
from .genome import create_mutate, create_distance, crossover
|
|
from .gene import BaseGene
|
|
|
|
def create_tell(config, gene_type: Type[BaseGene]):
    """Build the NEAT ``tell`` function for the given config and gene type.

    The returned ``tell(state, fitness)`` advances ``state`` by one generation.
    It updates species fitness and stagnation, chooses two parents for every
    slot of the next population, applies crossover and mutation, and finally
    re-speciates the new population.

    Args:
        config: configuration passed to ``create_mutate`` / ``create_distance``.
        gene_type: gene class used by the mutation and distance functions.

    Returns:
        ``tell(state, fitness) -> state``.
    """
    mutate = create_mutate(config, gene_type)
    distance = create_distance(config, gene_type)

    def update_species(state, randkey, fitness):
        """Refresh species statistics and pick parents for the next generation.

        Returns ``(state, winner, loser, elite_mask)``. ``winner``/``loser`` are
        per-slot population indices of the two parents (``winner`` has the
        higher or equal fitness). ``elite_mask`` marks slots copied without
        mutation.
        """
        # update the fitness of each species (max criterion)
        species_fitness = update_species_fitness(state, fitness)

        # remove stagnated species
        state, species_fitness = stagnation(state, species_fitness)

        # Sort species best-first. Empty and stagnated species have fitness
        # -inf, so (assuming fitness contains no NaN) they end up at the back.
        sort_indices = jnp.argsort(species_fitness)[::-1]

        state = state.update(
            species_info=state.species_info[sort_indices],
            center_nodes=state.center_nodes[sort_indices],
            center_conns=state.center_conns[sort_indices],
        )

        # decide the number of members of each species by their fitness rank
        spawn_number = cal_spawn_numbers(state)

        # crossover info
        winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness)

        return state, winner, loser, elite_mask

    def update_species_fitness(state, fitness):
        """Return the fitness of every species row: the max fitness of its members.

        Rows with no members (including NaN-keyed empty rows) get -inf.
        """

        def aux_func(idx):
            species_key = state.species_info[idx, 0]
            s_fitness = jnp.where(state.idx2species == species_key, fitness, -jnp.inf)
            return jnp.max(s_fitness)

        return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))

    def stagnation(state, species_fitness):
        """Detect stagnated species and remove them.

        A species stagnates when its current fitness is not better than its
        recorded best score AND more than ``state.max_stagnation`` generations
        have passed since that best score last improved. The top
        ``state.species_elitism`` species (by ``rank_elements``) never
        stagnate. Improved species get their best score / last-update
        generation refreshed. Stagnated species have their ``species_info``
        row and center genomes set to NaN and their fitness set to -inf.

        Returns ``(state, species_fitness)``.
        """

        def aux_func(idx):
            s_fitness = species_fitness[idx]
            # species_info row layout: [key, best score, last update generation, members_count]
            species_key, best_score, last_update, members_count = state.species_info[idx]
            st = (s_fitness <= best_score) & (state.generation - last_update > state.max_stagnation)
            improved = s_fitness > best_score
            last_update = jnp.where(improved, state.generation, last_update)
            best_score = jnp.where(improved, s_fitness, best_score)
            return st, jnp.array([species_key, best_score, last_update, members_count])

        spe_st, species_info = vmap(aux_func)(jnp.arange(species_fitness.shape[0]))

        # elite species never stagnate
        # NOTE(review): assumes rank_elements gives rank 0 to the fittest species — confirm in utils.
        species_rank = rank_elements(species_fitness)
        spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st)

        # set stagnated species to nan
        species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
        center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_nodes)
        center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_conns)
        species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)

        state = state.update(
            species_info=species_info,
            center_nodes=center_nodes,
            center_conns=center_conns,
        )

        return state, species_fitness

    def cal_spawn_numbers(state):
        """Decide how many offspring each species gets, by linear ranking selection.

        Species must already be sorted best-first, with valid (non-NaN key)
        species at the front. With N valid species, the species at position k
        gets rate (N - k) / (N (N + 1) / 2). Example: N = 3, P = 10 gives rates
        [0.5, 0.33, 0.17] and floored targets [5, 3, 1].

        The result moves from each species' previous size towards its target
        by ``state.spawn_number_change_rate``. Invalid species always get 0.
        The truncation remainder is added to the first (best) species, so the
        total is exactly ``state.P``.
        """
        is_species_valid = ~jnp.isnan(state.species_info[:, 0])
        valid_species_num = jnp.sum(is_species_valid)
        denominator = (valid_species_num + 1) * valid_species_num / 2  # e.g. 3 + 2 + 1 = 6

        rank_score = valid_species_num - jnp.arange(state.species_info.shape[0])  # e.g. [3, 2, 1]
        spawn_number_rate = rank_score / denominator  # e.g. [0.5, 0.33, 0.17]
        spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0)  # invalid species -> 0

        target_spawn_number = jnp.floor(spawn_number_rate * state.P)

        # Avoid too much variation of numbers in a species.
        # Invalid rows have a NaN member count; use 0 instead of casting NaN to
        # int32, whose result is implementation-dependent.
        previous_size = jnp.where(is_species_valid, state.species_info[:, 3], 0).astype(jnp.int32)
        spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
        spawn_number = spawn_number.astype(jnp.int32)
        # empty / stagnated species must not receive offspring
        spawn_number = jnp.where(is_species_valid, spawn_number, 0)

        # the sum of spawn_number must equal the population size
        error = state.P - jnp.sum(spawn_number)
        spawn_number = spawn_number.at[0].add(error)

        return spawn_number

    def create_crossover_pair(state, randkey, spawn_number, fitness):
        """Choose two parents for each slot of the next population.

        Species i owns a contiguous block of ``spawn_number[i]`` slots. Within
        a block, the first slots are elites: both parents are the same top
        member. The remaining slots draw both parents uniformly from the top
        ``survival_threshold`` fraction of the species (at least one member).

        Returns ``(winner, loser, elite_mask)``, each of length ``pop_size``.
        """
        species_size = state.species_info.shape[0]
        pop_size = fitness.shape[0]
        s_idx = jnp.arange(species_size)
        p_idx = jnp.arange(pop_size)

        def species_candidates(key, idx):
            members = state.idx2species == state.species_info[idx, 0]
            members_num = jnp.sum(members)

            # member indices sorted best-first; non-members (-inf) at the back
            members_fitness = jnp.where(members, fitness, -jnp.inf)
            sorted_member_indices = jnp.argsort(members_fitness)[::-1]

            # never mark more elites than there are members, otherwise
            # non-member genomes would be copied as elites of this species
            elite_size = jnp.minimum(state.genome_elitism, members_num)

            # floor(threshold * n) is 0 for small species, which made
            # select_pro 0 / 0 = NaN; keep at least the best member
            survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32)
            survive_size = jnp.maximum(survive_size, 1)

            # uniform probability over the top survive_size members
            select_pro = (p_idx < survive_size) / survive_size
            fa, ma = jax.random.choice(
                key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro
            )

            # elite slots: both parents are the same top genome
            is_elite = p_idx < elite_size
            fa = jnp.where(is_elite, sorted_member_indices, fa)
            ma = jnp.where(is_elite, sorted_member_indices, ma)
            return fa, ma, is_elite

        fas, mas, elites = vmap(species_candidates)(jax.random.split(randkey, species_size), s_idx)

        spawn_number_cum = jnp.cumsum(spawn_number)

        def slot_parents(idx):
            # species whose block contains population slot idx
            loc = jnp.argmax(idx < spawn_number_cum)
            # position inside that species' block (elites are at the beginning)
            idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
            return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

        part1, part2, elite_mask = vmap(slot_parents)(p_idx)

        is_part1_win = fitness[part1] >= fitness[part2]
        winner = jnp.where(is_part1_win, part1, part2)
        loser = jnp.where(is_part1_win, part2, part1)

        return winner, loser, elite_mask

    def create_next_generation(state, randkey, winner, loser, elite_mask):
        """Produce the next population by crossover and mutation.

        Slot j is ``crossover(winner[j], loser[j])``; non-elite slots are then
        mutated with new node key ``state.next_node_key + j``. Elite slots keep
        the unmutated crossover result. ``next_node_key`` becomes the largest
        node key in the new population plus one.
        """
        # prepare random keys
        pop_size = state.pop_nodes.shape[0]
        new_node_keys = jnp.arange(pop_size) + state.next_node_key

        k1, k2 = jax.random.split(randkey, 2)
        crossover_rand_keys = jax.random.split(k1, pop_size)
        mutate_rand_keys = jax.random.split(k2, pop_size)

        # batch crossover
        wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]  # winner nodes / conns
        lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]  # loser nodes / conns
        npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc)  # new nodes / conns

        # batch mutation
        mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0, 0))
        m_npn, m_npc = mutate_func(state, mutate_rand_keys, npn, npc, new_node_keys)

        # elites are not mutated
        pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
        pop_conns = jnp.where(elite_mask[:, None, None], npc, m_npc)

        # update next node key (NaN marks empty node rows)
        all_nodes_keys = pop_nodes[:, :, 0]
        max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
        next_node_key = max_node_key + 1

        return state.update(
            pop_nodes=pop_nodes,
            pop_conns=pop_conns,
            next_node_key=next_node_key,
        )

    def speciate(state):
        """Assign every genome of the current population to a species.

        Step 1: for each existing species (in order, stopping at the first
        NaN row), the closest still-unassigned genome becomes its new center.
        Step 2: each species claims genomes closer than
        ``state.compatibility_threshold`` to its center. A genome is claimed
        when it is unassigned or closer to this center than to its current
        one. Free rows become new species seeded by the first unassigned
        genome. Genomes left over when the table is full go to the last
        species. Finally, member counts are recomputed.
        """
        pop_size, species_size = state.pop_nodes.shape[0], state.center_nodes.shape[0]

        # distance from one center genome to every genome of the population
        o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0, 0))

        # idx -> species key; NaN means not assigned to any species
        idx2specie = jnp.full((pop_size,), jnp.nan)

        # distance between each genome and the center of its species
        o2c_distances = jnp.full((pop_size,), jnp.inf)

        # step 1: find new centers
        def find_center_cond(carry):
            i, i2s, cn, cc, o2c = carry
            species_key = state.species_info[i, 0]
            return (i < species_size) & (~jnp.isnan(species_key))  # current species exists

        def find_center_body(carry):
            i, i2s, cn, cc, o2c = carry
            distances = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns)

            # the closest genome that is not assigned yet
            closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))

            i2s = i2s.at[closest_idx].set(state.species_info[i, 0])
            cn = cn.at[i].set(state.pop_nodes[closest_idx])
            cc = cc.at[i].set(state.pop_conns[closest_idx])

            # the new center is at distance 0 from its own center
            o2c = o2c.at[closest_idx].set(0)

            return i + 1, i2s, cn, cc, o2c

        _, idx2specie, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
            find_center_cond,
            find_center_body,
            (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances),
        )

        # step 2: assign members to each species
        def assign_cond(carry):
            i, i2s, cn, cc, si, o2c, nsk = carry  # si: species_info, nsk: next_species_key
            current_species_existed = ~jnp.isnan(si[i, 0])
            not_all_assigned = jnp.any(jnp.isnan(i2s))
            not_reach_species_upper_bounds = i < species_size
            return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)

        def assign_body(carry):
            i, i2s, cn, cc, si, o2c, nsk = carry
            # the branch's returned index is ignored; this body advances i itself
            _, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
                jnp.isnan(si[i, 0]),  # whether the current species exists
                create_new_species,  # if not, create a new species in row i
                update_exist_specie,  # if it does, let it claim close genomes
                (i, i2s, cn, cc, si, o2c, nsk),
            )
            return i + 1, i2s, scn, scc, si, o2c, nsk

        def create_new_species(carry):
            i, i2s, cn, cc, si, o2c, nsk = carry

            # pick the first genome that has not been assigned to any species
            idx = fetch_first(jnp.isnan(i2s))

            # [key, best score, last update generation, members_count]
            si = si.at[i].set(jnp.array([nsk, -jnp.inf, state.generation, 0]))
            i2s = i2s.at[idx].set(nsk)
            o2c = o2c.at[idx].set(0)

            # that genome becomes the center of the new species
            cn = cn.at[i].set(state.pop_nodes[idx])
            cc = cc.at[i].set(state.pop_conns[idx])

            # the new species immediately claims genomes close to its center
            i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))

            return i + 1, i2s, cn, cc, si, o2c, nsk + 1  # advance to the next new species key

        def update_exist_specie(carry):
            i, i2s, cn, cc, si, o2c, nsk = carry
            i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
            return i + 1, i2s, cn, cc, si, o2c, nsk

        def speciate_by_threshold(carry):
            i, i2s, cn, cc, si, o2c = carry

            # distance between this center genome and every population genome
            o2p_distance = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns)
            close_enough_mask = o2p_distance < state.compatibility_threshold

            # claim a genome when it is unassigned or closer to this center than to its current one
            cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
            mask = close_enough_mask & cacheable_mask

            i2s = jnp.where(mask, si[i, 0], i2s)
            o2c = jnp.where(mask, o2p_distance, o2c)

            return i2s, o2c

        _, idx2specie, center_nodes, center_conns, species_info, _, next_species_key = jax.lax.while_loop(
            assign_cond,
            assign_body,
            (0, idx2specie, center_nodes, center_conns, state.species_info, o2c_distances, state.next_species_key),
        )

        # Genomes can only still be unassigned when the species table is full;
        # put them into the last species.
        idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie)

        # update members count (NaN for empty species rows)
        def count_members(idx):
            key = species_info[idx, 0]
            count = jnp.sum(idx2specie == key)
            return jnp.where(jnp.isnan(key), jnp.nan, count)

        species_member_counts = vmap(count_members)(jnp.arange(species_size))
        species_info = species_info.at[:, 3].set(species_member_counts)

        return state.update(
            # update_species_fitness and create_crossover_pair read
            # ``state.idx2species``; writing only ``idx2specie`` never refreshed it.
            idx2species=idx2specie,
            # NOTE(review): old key kept in case other code reads it; drop once confirmed unused.
            idx2specie=idx2specie,
            center_nodes=center_nodes,
            center_conns=center_conns,
            species_info=species_info,
            next_species_key=next_species_key,
        )

    def tell(state, fitness):
        """Main update function in NEAT: advance ``state`` by one generation given ``fitness``."""
        k1, k2, randkey = jax.random.split(state.randkey, 3)

        state = state.update(
            generation=state.generation + 1,
            randkey=randkey,
        )

        state, winner, loser, elite_mask = update_species(state, k1, fitness)
        state = create_next_generation(state, k2, winner, loser, elite_mask)
        state = speciate(state)

        return state

    return tell
|
|
|
|
|
|
def argmin_with_mask(arr, mask):
    """Return the index of the smallest ``arr`` value among positions where ``mask`` is True.

    Positions whose mask is False are replaced by +inf before the argmin, so
    they are never preferred over a finite, unmasked value. If every position
    is masked out, all values are +inf and ``jnp.argmin`` returns index 0.
    """
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))