new architecture

This commit is contained in:
wls2002
2024-01-27 00:52:39 +08:00
parent 4efe9a53c1
commit aac41a089d
65 changed files with 1651 additions and 1783 deletions

View File

@@ -1,2 +1,2 @@
from .species_info import SpeciesInfo
from .operations import update_species, speciate
from .base import BaseSpecies
from .default import DefaultSpecies

View File

@@ -0,0 +1,14 @@
from utils import State
class BaseSpecies:
def setup(self, randkey):
raise NotImplementedError
def ask(self, state: State):
raise NotImplementedError
def update_species(self, state, fitness, generation):
raise NotImplementedError
def speciate(self, state, generation):
raise NotImplementedError

View File

@@ -0,0 +1,514 @@
import numpy as np
import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome
class DefaultSpecies:
def __init__(self,
genome: BaseGenome,
pop_size,
species_size,
compatibility_disjoint: float = 1.0,
compatibility_weight: float = 0.4,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.5
):
self.genome = genome
self.pop_size = pop_size
self.species_size = species_size
self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight
self.max_stagnation = max_stagnation
self.species_elitism = species_elitism
self.spawn_number_change_rate = spawn_number_change_rate
self.genome_elitism = genome_elitism
self.survival_threshold = survival_threshold
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey):
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
last_improved = jnp.full((self.species_size,), jnp.nan) # the last generation that the species improved
member_count = jnp.full((self.species_size,), jnp.nan) # the number of members of each species
idx2species = jnp.zeros(self.pop_size) # the species index of each individual
# nodes for each center genome of each species
center_nodes = jnp.full((self.species_size, self.genome.max_nodes, self.genome.node_gene.length), jnp.nan)
# connections for each center genome of each species
center_conns = jnp.full((self.species_size, self.genome.max_conns, self.genome.conn_gene.length), jnp.nan)
species_keys = species_keys.at[0].set(0)
best_fitness = best_fitness.at[0].set(-jnp.inf)
last_improved = last_improved.at[0].set(0)
member_count = member_count.at[0].set(self.pop_size)
center_nodes = center_nodes.at[0].set(pop_nodes[0])
center_conns = center_conns.at[0].set(pop_conns[0])
return State(
randkey=randkey,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=1, # 0 is reserved for the first species
)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def update_species(self, state, fitness, generation):
# update the fitness of each species
species_fitness = self.update_species_fitness(state, fitness)
# stagnation species
state, species_fitness = self.stagnation(state, generation, species_fitness)
# sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices],
member_count=state.member_count[sort_indices],
center_nodes=state.center_nodes[sort_indices],
center_conns=state.center_conns[sort_indices],
)
# decide the number of members of each species by their fitness
spawn_number = self.cal_spawn_numbers(state)
k1, k2 = jax.random.split(state.randkey)
# crossover info
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
return state(randkey=k2), winner, loser, elite_mask
def update_species_fitness(self, state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf)
val = jnp.max(s_fitness)
return val
return jax.vmap(aux_func)(self.species_arange)
def stagnation(self, state, generation, species_fitness):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
generation: the current generation
"""
def check_stagnation(idx):
# determine whether the species stagnation
st = (
(species_fitness[idx] <= state.best_fitness[
idx]) & # not better than the best fitness of the species
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
)
# update last_improved and best_fitness
li, bf = jax.lax.cond(
species_fitness[idx] > state.best_fitness[idx],
lambda: (generation, species_fitness[idx]), # update
lambda: (state.last_improved[idx], state.best_fitness[idx]) # not update
)
return st, bf, li
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(self.species_arange)
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(species_rank < self.species_elitism, False, spe_st) # elitism never stagnation
# set stagnation species to nan
def update_func(idx):
return jax.lax.cond(
spe_st[idx],
lambda: (
jnp.nan, # species_key
jnp.nan, # best_fitness
jnp.nan, # last_improved
jnp.nan, # member_count
-jnp.inf, # species_fitness
jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(center_conns[idx], jnp.nan), # center_conns
), # stagnation species
lambda: (
species_keys[idx],
best_fitness[idx],
last_improved[idx],
state.member_count[idx],
species_fitness[idx],
center_nodes[idx],
center_conns[idx]
) # not stagnation species
)
(
species_keys,
best_fitness,
last_improved,
member_count,
species_fitness,
center_nodes,
center_conns
) = (
jax.vmap(update_func)(self.species_arange))
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
center_nodes=center_nodes,
center_conns=center_conns,
), species_fitness
def cal_spawn_numbers(self, state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
species_keys = state.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * self.pop_size) # calculate member
# Avoid too much variation of numbers for a species
previous_size = state.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
return spawn_number
def create_crossover_pair(self, state, randkey, spawn_number, fitness):
s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size)
def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
survive_size = jnp.floor(self.survival_threshold * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, self.pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa)
ma = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, ma)
elite = jnp.where(p_idx < self.genome_elitism, True, False)
return fa, ma, elite
fas, mas, elites = jax.vmap(aux_func)(jax.random.split(randkey, self.species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = jax.vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def speciate(self, state, generation):
# prepare distance functions
o2p_distance_func = jax.vmap(self.distance, in_axes=(None, None, 0, 0)) # one to population
# idx to specie key
idx2species = jnp.full((self.pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, o2c_distances
i, i2s, cns, ccs, o2c = carry
return (
(i < self.species_size) &
(~jnp.isnan(state.species_keys[i]))
) # current species is existing
def body_func(carry):
i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cns = cns.set(i, state.pop_nodes[closest_idx])
ccs = ccs.set(i, state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cns, ccs, o2c
_, idx2species, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func,
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances))
state = state.update(
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
)
# part 2: assign members to each species
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key
i, i2s, cns, ccs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < self.species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
_, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cns, ccs, sk, o2c, nsk)
)
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk) # nsk -> next species key
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cns = cns.set(i, state.pop_nodes[idx])
ccs = ccs.set(i, state.pop_conns[idx])
# find the members for the new species
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
# turn to next species
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
close_enough_mask = o2p_distance < self.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
mask = close_enough_mask & catchable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances,
state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
last_improved = jnp.where(new_created_mask, generation, state.last_improved)
# update members count
def count_members(idx):
return jax.lax.cond(
jnp.isnan(species_keys[idx]), # if the species is not existing
lambda _: jnp.nan, # nan
lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
)
member_count = jax.vmap(count_members)(self.species_arange)
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=next_species_key
)
def distance(self, nodes1, conns1, nodes2, conns2):
"""
The distance between two genomes
"""
return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
def node_distance(self, nodes1, nodes2):
"""
The distance of the nodes part for two genomes
"""
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
def conn_distance(self, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((conns1, conns2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt)
def initialize_population(pop_size, genome):
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
input_idx, output_idx = genome.input_idx, genome.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
o_nodes[input_idx, 0] = genome.input_idx
o_nodes[output_idx, 0] = genome.output_idx
o_nodes[new_node_key, 0] = new_node_key # one hidden node
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs()
o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs()
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs()
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
pop_conns = np.tile(o_conns, (pop_size, 1, 1))
return pop_nodes, pop_conns

View File

@@ -1,71 +0,0 @@
from jax import Array, numpy as jnp, vmap
from core import Gene
def distance(gene: Gene, state, genome1, genome2):
return node_distance(gene, state, genome1.nodes, genome2.nodes) + \
connection_distance(gene, state, genome1.conns, genome2.conns)
def node_distance(gene: Gene, state, nodes1: Array, nodes2: Array):
"""
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = vmap(gene.distance_node, in_axes=(None, 0, 0))(state, fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
def connection_distance(gene: Gene, state, cons1: Array, cons2: Array):
"""
Calculate the distance between connections of two genomes.
Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = vmap(gene.distance_conn, in_axes=(None, 0, 0))(state, fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt)

View File

@@ -1,319 +0,0 @@
import jax
from jax import numpy as jnp, vmap
from core import Gene, Genome, State
from utils import rank_elements, fetch_first
from .distance import distance
from .species_info import SpeciesInfo
def update_species(state, randkey, fitness):
# update the fitness of each species
species_fitness = update_species_fitness(state, fitness)
# stagnation species
state, species_fitness = stagnation(state, species_fitness)
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update(
species_info=state.species_info[sort_indices],
center_genomes=state.center_genomes[sort_indices],
)
# decide the number of members of each species by their fitness
spawn_number = cal_spawn_numbers(state)
# crossover info
winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness)
return state, winner, loser, elite_mask
def update_species_fitness(state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_info.species_keys[idx], fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(state.species_info.size()))
def stagnation(state, species_fitness):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
"""
def aux_func(idx):
s_fitness = species_fitness[idx]
sk, bf, li, _ = state.species_info.get(idx)
st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation)
li = jnp.where(s_fitness > bf, state.generation, li)
bf = jnp.where(s_fitness > bf, s_fitness, bf)
return st, sk, bf, li
spe_st, species_keys, best_fitness, last_improved = vmap(aux_func)(jnp.arange(species_fitness.shape[0]))
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation
# set stagnation species to nan
species_keys = jnp.where(spe_st, jnp.nan, species_keys)
best_fitness = jnp.where(spe_st, jnp.nan, best_fitness)
last_improved = jnp.where(spe_st, jnp.nan, last_improved)
member_count = jnp.where(spe_st, jnp.nan, state.species_info.member_count)
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
# TODO: Simplify the coded
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes)
center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns)
state = state.update(
species_info=species_info,
center_genomes=Genome(center_nodes, center_conns)
)
return state, species_fitness
def cal_spawn_numbers(state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
species_keys = state.species_info.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(species_keys.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# Avoid too much variation of numbers in a species
previous_size = state.species_info.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_info.size()
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
# def aux_func(key, idx):
def aux_func(key, idx):
members = state.idx2species == state.species_info.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
elite_size = state.genome_elitism
survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
elite = jnp.where(p_idx < elite_size, True, False)
return fa, ma, elite
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def speciate(gene: Gene, state: State):
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population
# idx to specie key
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
i, i2s, cgs, o2c = carry
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
def body_func(carry):
i, i2s, cgs, o2c = carry
distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cgs = cgs.set(i, state.pop_genomes[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cgs, o2c
_, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
state = state.update(
idx2species=idx2species,
center_genomes=center_genomes,
)
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cgs, sk, o2c, nsk)
)
return i + 1, i2s, cgs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cgs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cgs = cgs.set(i, state.pop_genomes[idx])
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cgs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# turn to next species
return i + 1, i2s, cgs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
close_enough_mask = o2p_distance < state.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
return state.update(
species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key
)
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx

View File

@@ -1,55 +0,0 @@
from jax.tree_util import register_pytree_node_class
import numpy as np
import jax.numpy as jnp
@register_pytree_node_class
class SpeciesInfo:
def __init__(self, species_keys, best_fitness, last_improved, member_count):
self.species_keys = species_keys
self.best_fitness = best_fitness
self.last_improved = last_improved
self.member_count = member_count
@classmethod
def initialize(cls, state):
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
return cls(species_keys, best_fitness, last_improved, member_count)
def __getitem__(self, i):
return SpeciesInfo(self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i])
def get(self, i):
return self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i]
def set(self, idx, value):
species_keys = self.species_keys.at[idx].set(value[0])
best_fitness = self.best_fitness.at[idx].set(value[1])
last_improved = self.last_improved.at[idx].set(value[2])
member_count = self.member_count.at[idx].set(value[3])
return SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
def remove(self, idx):
return self.set(idx, jnp.array([jnp.nan] * 4))
def size(self):
return self.species_keys.shape[0]
def tree_flatten(self):
children = self.species_keys, self.best_fitness, self.last_improved, self.member_count
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)