Files
tensorneat-mend/algorithms/neat/jitable_speciate.py
wls2002 6006f92f3f finish jit-able speciate function
next time i'll create a new branch
2023-05-12 19:26:02 +08:00

110 lines
4.0 KiB
Python

from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance
from .genome.utils import I_INT, fetch_first, argmin_with_mask
@jit
def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
                     disjoint_coe: float = 1., compatibility_coe: float = 0.5,
                     compatibility_threshold: float = 3.0
                     ):
    """
    Assign every genome in the population to a species (jit-compatible).

    Species slots are visited in index order. For each slot a center genome is chosen
    from the still-unassigned genomes, then every unassigned genome whose distance to
    that center is below ``compatibility_threshold`` joins the species. The loop stops
    as soon as every genome is assigned or every slot has been visited.

    args:
        pop_nodes: (pop_size, N, 5)
        pop_cons: (pop_size, C, 4)
        spe_center_nodes: (species_size, N, 5); a slot whose nodes are all NaN is
            treated as an empty (not yet existing) species
        spe_center_cons: (species_size, C, 4)
        disjoint_coe: forwarded to ``distance`` as ``disjoint_coe``
        compatibility_coe: forwarded to ``distance`` as ``compatibility_coe``
        compatibility_threshold: a genome joins a species when its distance to the
            species center is strictly below this value

    returns:
        idx2specie: (pop_size,) int32, the species slot index of each genome
        spe_center_nodes: (species_size, N, 5), centers after the update; slots the
            loop never visited are returned unchanged
        spe_center_cons: (species_size, C, 4), same as above for connections
    """
    pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]

    # prepare distance functions
    distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
    # one genome against the whole population -> (pop_size,)
    o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
    # every species center against the whole population -> (species_size, pop_size)
    s2p_distance_func = vmap(o2p_distance_func, in_axes=(0, 0, None, None))

    # I_INT means not assigned to any species
    idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32)

    # the distance between each species' (incoming) center and each genome in population
    s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)

    def continue_execute_while(carry):
        i, i2s, scn, scc = carry
        not_all_assigned = jnp.any(i2s == I_INT)
        not_reach_species_upper_bounds = i < species_size
        return not_all_assigned & not_reach_species_upper_bounds

    def deal_with_each_center_genome(carry):
        # scn is short for spe_center_nodes, scc is short for spe_center_cons
        i, i2s, scn, scc = carry
        # Slot i is only written during iteration i (below), so scn[i] still holds
        # the incoming center here.
        i2s, scn, scc = jax.lax.cond(
            jnp.all(jnp.isnan(scn[i])),  # all-NaN nodes mark an empty species slot
            create_new_specie,  # empty slot: create a new specie
            update_exist_specie,  # occupied slot: update the specie
            (i, i2s, scn, scc)
        )
        return i + 1, i2s, scn, scc

    def _set_center_and_speciate(i, idx, i2s, scn, scc):
        # Make genome `idx` the center of species `i`, then pull in nearby genomes.
        i2s = i2s.at[idx].set(i)
        scn = scn.at[i].set(pop_nodes[idx])
        scc = scc.at[i].set(pop_cons[idx])
        return speciate_by_threshold((i, i2s, scn, scc))

    def create_new_specie(carry):
        i, i2s, scn, scc = carry
        # pick the first one who has not been assigned to any species
        idx = fetch_first(i2s == I_INT)
        return _set_center_and_speciate(i, idx, i2s, scn, scc)

    def update_exist_specie(carry):
        i, i2s, scn, scc = carry
        # new center: among unassigned genomes, the one selected by argmin_with_mask
        # over the distances to the previous center of species i
        idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
        return _set_center_and_speciate(i, idx, i2s, scn, scc)

    def speciate_by_threshold(carry):
        i, i2s, scn, scc = carry
        # distance between such center genome and pop genomes
        o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
        close_enough_mask = o2p_distance < compatibility_threshold
        # when it is close enough, assign it to the species;
        # genomes that have already been assigned are left untouched
        i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s)
        return i2s, scn, scc

    # update idx2specie
    _, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop(
        continue_execute_while,
        deal_with_each_center_genome,
        (0, idx2specie, spe_center_nodes, spe_center_cons)
    )

    # if there are still some pop genomes not assigned to any species, add them to the last species
    # this seems to happen only when the number of species has reached the species upper bound
    idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie)

    return idx2specie, spe_center_nodes, spe_center_cons