320 lines
12 KiB
Python
320 lines
12 KiB
Python
import jax
|
|
from jax import numpy as jnp, vmap
|
|
|
|
from core import Gene, Genome, State
|
|
from utils import rank_elements, fetch_first
|
|
from .distance import distance
|
|
from .species_info import SpeciesInfo
|
|
|
|
|
|
def update_species(state, randkey, fitness):
    """
    Run the per-generation species bookkeeping and pick parents for the next population.

    Steps:
      1. compute each species' fitness (max over its members),
      2. remove stagnant species (their fitness becomes -inf),
      3. reorder species slots from best to worst fitness (-inf slots end up last),
      4. decide the offspring count per species and build crossover pairs.

    Returns:
        (state, winner, loser, elite_mask)
    """
    per_species_fitness = update_species_fitness(state, fitness)

    state, per_species_fitness = stagnation(state, per_species_fitness)

    # descending order: best species first, removed/empty (-inf) slots at the end
    order = jnp.flip(jnp.argsort(per_species_fitness))
    state = state.update(
        species_info=state.species_info[order],
        center_genomes=state.center_genomes[order],
    )

    offspring_counts = cal_spawn_numbers(state)
    winner, loser, elite_mask = create_crossover_pair(state, randkey, offspring_counts, fitness)

    return state, winner, loser, elite_mask
|
|
|
|
|
|
def update_species_fitness(state, fitness):
    """
    Compute one fitness value per species slot.

    A species' fitness is the maximum fitness among the individuals whose
    ``state.idx2species`` entry equals that species' key. Slots that no
    individual belongs to (e.g. NaN-key slots) get ``-inf``.
    """
    species_keys = state.species_info.species_keys

    def best_of_species(slot):
        in_species = state.idx2species == species_keys[slot]
        return jnp.where(in_species, fitness, -jnp.inf).max()

    slots = jnp.arange(state.species_info.size())
    return vmap(best_of_species)(slots)
|
|
|
|
|
|
def stagnation(state, species_fitness):
    """
    Remove species that have stopped improving.

    A species slot is flagged as stagnant when its current fitness does not exceed
    its recorded best fitness AND more than ``state.max_stagnation`` generations
    have passed since it last improved. Slots whose ``rank_elements`` rank is below
    ``state.species_elitism`` are never flagged (presumably rank 0 is the best
    species -- confirm against ``utils.rank_elements``).

    Flagged slots are cleared: key, best fitness, last-improved generation, member
    count and both center-genome arrays are overwritten with NaN, and the slot's
    entry in the returned fitness array becomes -inf. For the other slots the best
    fitness and last-improved generation are refreshed when the current fitness
    beats the record.

    Returns:
        (state, species_fitness) with the updates above applied.
    """
    generation = state.generation

    def judge(slot):
        current = species_fitness[slot]
        # NOTE(review): assumes get() yields (key, best_fitness, last_improved, member_count)
        key, best, last_improved, _ = state.species_info.get(slot)
        # judged against the *old* record, before it is refreshed below
        stale = (current <= best) & (generation - last_improved > state.max_stagnation)
        improved = current > best
        return (
            stale,
            key,
            jnp.where(improved, current, best),
            jnp.where(improved, generation, last_improved),
        )

    slots = jnp.arange(species_fitness.shape[0])
    stale, keys, best_fitness, last_improved = vmap(judge)(slots)

    # elite species are protected from removal
    protected = rank_elements(species_fitness) < state.species_elitism
    stale = jnp.where(protected, False, stale)

    def clear(values):
        return jnp.where(stale, jnp.nan, values)

    species_info = SpeciesInfo(
        clear(keys),
        clear(best_fitness),
        clear(last_improved),
        clear(state.species_info.member_count),
    )

    genome_mask = stale[:, None, None]
    center_genomes = Genome(
        jnp.where(genome_mask, jnp.nan, state.center_genomes.nodes),
        jnp.where(genome_mask, jnp.nan, state.center_genomes.conns),
    )

    new_state = state.update(species_info=species_info, center_genomes=center_genomes)
    return new_state, jnp.where(stale, -jnp.inf, species_fitness)
|
|
|
|
|
|
def cal_spawn_numbers(state):
    """
    Decide how many offspring each species slot receives, by fitness rank.

    Linear ranking selection over the valid (non-NaN key) species, which are
    assumed to occupy the leading slots in best-to-worst order (``update_species``
    sorts them that way before calling this). With V valid species, slot k gets
    the share ``(V - k) / (V * (V + 1) / 2)`` of the population ``state.P``.

    e.g. V = 3, P = 10 -> shares [3/6, 2/6, 1/6] -> targets floor([5, 3.33, 1.67]) = [5, 3, 1]

    To avoid abrupt changes, each species only moves from its previous member
    count toward its target by the fraction ``state.spawn_number_change_rate``.
    The result is truncated to int32, invalid slots contribute 0, and the rounding
    leftover (``P - sum``) is added to slot 0 so that the counts sum to ``state.P``.

    Returns:
        int32 array with one offspring count per species slot.
    """
    species_keys = state.species_info.species_keys

    is_species_valid = ~jnp.isnan(species_keys)
    valid_species_num = jnp.sum(is_species_valid)
    denominator = (valid_species_num + 1) * valid_species_num / 2  # e.g. 3 + 2 + 1 = 6

    rank_score = valid_species_num - jnp.arange(species_keys.shape[0])  # e.g. [3, 2, 1]
    spawn_number_rate = rank_score / denominator  # e.g. [0.5, 0.33, 0.17]
    spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0)  # invalid slots get nothing

    target_spawn_number = jnp.floor(spawn_number_rate * state.P)

    # Avoid too much variation of numbers in a species.
    # Invalid / removed slots carry NaN member counts; zero them so the int cast
    # below never has to convert NaN (NaN -> int is not well defined).
    previous_size = jnp.where(is_species_valid, state.species_info.member_count, 0)
    spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
    spawn_number = spawn_number.astype(jnp.int32)

    # the counts must sum exactly to the population size: give the leftover to slot 0
    error = state.P - jnp.sum(spawn_number)
    spawn_number = spawn_number.at[0].add(error)

    return spawn_number
|
|
|
|
|
|
def create_crossover_pair(state, randkey, spawn_number, fitness):
    """
    Choose a parent pair for every individual of the next population.

    The next population is laid out species by species: the first
    ``spawn_number[0]`` slots belong to species slot 0, the next
    ``spawn_number[1]`` to slot 1, and so on (``spawn_number`` is expected to sum
    to ``pop_size``).

    Within a species, members are ranked by fitness (best first):

    * the first ``min(state.genome_elitism, members)`` offspring are elites: both
      parents are that same member and ``elite_mask`` is True;
    * every other offspring draws two parents uniformly, with replacement, from the
      best ``max(floor(state.survival_threshold * members), 1)`` members.

    Args:
        state: reads species_info, idx2species, genome_elitism, survival_threshold.
        randkey: JAX PRNG key.
        spawn_number: integer array, offspring count per species slot.
        fitness: fitness of the current population, shape (pop_size,).

    Returns:
        (winner, loser, elite_mask), each of shape (pop_size,): the fitter parent's
        index (ties favour the first draw), the other parent's index, and the elite flag.
    """
    species_size = state.species_info.size()
    pop_size = fitness.shape[0]
    s_idx = jnp.arange(species_size)
    p_idx = jnp.arange(pop_size)

    def select_parents(key, idx):
        """Parent candidates for every possible offspring position of species slot ``idx``."""
        members = state.idx2species == state.species_info.species_keys[idx]
        members_num = jnp.sum(members)

        members_fitness = jnp.where(members, fitness, -jnp.inf)
        sorted_member_indices = jnp.argsort(members_fitness)[::-1]  # members first, best first

        # never mark more elites than the species actually has members, otherwise
        # non-members (fitness -inf) would be copied as elites
        elite_size = jnp.minimum(state.genome_elitism, members_num)

        # at least one survivor: with 0 the selection probabilities would be 0/0 = NaN
        survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32)
        survive_size = jnp.maximum(survive_size, 1)

        select_pro = (p_idx < survive_size) / survive_size
        fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)

        # elites: the first positions copy the top members unchanged
        is_elite = p_idx < elite_size
        fa = jnp.where(is_elite, sorted_member_indices, fa)
        ma = jnp.where(is_elite, sorted_member_indices, ma)

        return fa, ma, is_elite

    fas, mas, elites = vmap(select_parents)(jax.random.split(randkey, species_size), s_idx)

    spawn_number_cum = jnp.cumsum(spawn_number)

    def locate(idx):
        """Map a population index to (species slot, position inside that species)."""
        loc = jnp.argmax(idx < spawn_number_cum)

        # elite genomes are at the beginning of the species
        idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)

        return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

    part1, part2, elite_mask = vmap(locate)(p_idx)

    is_part1_win = fitness[part1] >= fitness[part2]
    winner = jnp.where(is_part1_win, part1, part2)
    loser = jnp.where(is_part1_win, part2, part1)

    return winner, loser, elite_mask
|
|
|
|
|
|
def speciate(gene: Gene, state: State):
    """
    Assign every genome of the population to a species and refresh species centers.

    Both phases run as ``jax.lax.while_loop``s so the function stays traceable.

    1. For each existing species slot (stopping at the first slot whose key is
       NaN), the still-unassigned genome closest to the old center becomes the
       new center and is assigned to that species.
    2. Walking the slots again from 0:
       * an existing species absorbs every genome within
         ``state.compatibility_threshold`` of its center that is either
         unassigned or closer to this center than to its current one;
       * an empty (NaN-key) slot gets a new species with key
         ``next_species_key``, seeded by an unassigned genome chosen via
         ``fetch_first`` (presumably the first unassigned index), and then
         absorbs genomes the same way.
       The loop stops at the slot capacity, or when the current slot is empty
       and every genome is already assigned.

    Genomes still unassigned afterwards are put into the species of the last
    slot. Newly created species get ``best_fitness = -inf`` and
    ``last_improved = state.generation``, and member counts are recomputed
    (NaN for empty slots).

    Args:
        gene: passed through to ``distance``.
        state: reads idx2species, species_info, center_genomes, pop_genomes,
            compatibility_threshold, next_species_key and generation.

    Returns:
        ``state`` with updated species_info, idx2species, center_genomes and
        next_species_key.
    """
    pop_size, species_size = state.idx2species.shape[0], state.species_info.size()

    # prepare distance functions
    o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0))  # one genome vs. the whole population

    # idx to species key
    idx2species = jnp.full((pop_size,), jnp.nan)  # NaN means not assigned to any species

    # distance from each genome to the center of the species it is assigned to
    o2c_distances = jnp.full((pop_size,), jnp.inf)

    # step 1: find new centers
    def cond_func(carry):
        i, i2s, cgs, o2c = carry

        # continue while inside capacity and the current slot holds an existing species
        return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i]))  # current species is existing

    def body_func(carry):
        i, i2s, cgs, o2c = carry

        distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)

        # find the closest genome that is not yet assigned to any species
        closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))

        i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
        cgs = cgs.set(i, state.pop_genomes[closest_idx])

        # the genome with closest_idx will become the new center, thus its distance to center is 0.
        o2c = o2c.at[closest_idx].set(0)

        return i + 1, i2s, cgs, o2c

    _, idx2species, center_genomes, o2c_distances = \
        jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))

    state = state.update(
        idx2species=idx2species,
        center_genomes=center_genomes,
    )

    # part 2: assign members to each species
    # carry = (slot index, idx2species, center genomes, species keys, distance to own center, next species key)
    def cond_func(carry):
        i, i2s, cgs, sk, o2c, nsk = carry

        current_species_existed = ~jnp.isnan(sk[i])
        not_all_assigned = jnp.any(jnp.isnan(i2s))
        not_reach_species_upper_bounds = i < species_size
        return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)

    def body_func(carry):
        i, i2s, cgs, sk, o2c, nsk = carry

        # the index returned by either branch is discarded; the slot index advances below
        _, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
            jnp.isnan(sk[i]),  # whether the current slot is empty
            create_new_species,  # if empty, create a new species in it
            update_exist_specie,  # if existing, update the species
            (i, i2s, cgs, sk, o2c, nsk)
        )

        return i + 1, i2s, cgs, sk, o2c, nsk

    def create_new_species(carry):
        i, i2s, cgs, sk, o2c, nsk = carry

        # pick an unassigned genome (via fetch_first) as the seed of the new species
        idx = fetch_first(jnp.isnan(i2s))

        # give slot i the new key and assign the seed genome to it
        sk = sk.at[i].set(nsk)
        i2s = i2s.at[idx].set(nsk)
        o2c = o2c.at[idx].set(0)

        # the seed genome becomes the center of the new species
        cgs = cgs.set(i, state.pop_genomes[idx])

        # absorb nearby genomes right away
        i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)

        # the returned index is ignored by body_func; the key counter moves on to the next new key
        return i + 1, i2s, cgs, sk, o2c, nsk + 1

    def update_exist_specie(carry):
        i, i2s, cgs, sk, o2c, nsk = carry

        i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)

        # turn to next species (index ignored by body_func)
        return i + 1, i2s, cgs, sk, o2c, nsk

    def speciate_by_threshold(i, i2s, cgs, sk, o2c):
        # distance between the center genome of slot i and every population genome
        o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
        close_enough_mask = o2p_distance < state.compatibility_threshold

        # a genome may move here if it is unassigned, or if this center is closer
        # than the center of the species it currently belongs to
        cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
        mask = close_enough_mask & cacheable_mask

        # update species assignment
        i2s = jnp.where(mask, sk[i], i2s)

        # update distance to the (new) own center
        o2c = jnp.where(mask, o2p_distance, o2c)

        return i2s, o2c

    # update idx2species
    _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
        cond_func,
        body_func,
        (0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
         state.next_species_key)
    )

    # if some population genomes are still not assigned to any species, add them to the species in the last slot.
    # per the original author this can only happen when every species slot is in use.
    idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)

    # complete info of species which is created in this generation
    # (key now set, but no best fitness recorded yet)
    new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
    best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
    last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)

    # update members count (NaN for empty slots)
    def count_members(idx):
        key = species_keys[idx]
        count = jnp.sum(idx2species == key, dtype=jnp.float32)
        count = jnp.where(jnp.isnan(key), jnp.nan, count)

        return count

    member_count = vmap(count_members)(jnp.arange(species_size))

    return state.update(
        species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
        idx2species=idx2species,
        center_genomes=center_genomes,
        next_species_key=next_species_key
    )
|
|
|
|
|
|
def argmin_with_mask(arr, mask):
    """
    Return the index of the smallest element of ``arr`` among positions where ``mask`` is True.

    Positions where ``mask`` is False are treated as +inf. If every position is
    masked out, ``jnp.argmin`` of an all-inf array yields index 0.
    """
    candidates = jnp.where(mask, arr, jnp.inf)
    return jnp.argmin(candidates)
|