finish jit-able speciate function
next time i'll create a new branch
This commit is contained in:
@@ -76,6 +76,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
a = jnp.array([1, 2, 3, 4, 5])
|
||||
|
||||
@@ -1,4 +1,109 @@
|
||||
from jax import jit
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from jax import Array
|
||||
|
||||
from .genome import distance
|
||||
from .genome.utils import I_INT, fetch_first, argmin_with_mask
|
||||
|
||||
|
||||
@jit
|
||||
def jitable_speciate():
|
||||
pass
|
||||
def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
|
||||
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
|
||||
):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
pop_cons: (pop_size, C, 4)
|
||||
spe_center_nodes: (species_size, N, 5)
|
||||
spe_center_cons: (species_size, C, 4)
|
||||
"""
|
||||
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
|
||||
|
||||
# prepare distance functions
|
||||
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
|
||||
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
s2p_distance_func = vmap(
|
||||
o2p_distance_func, in_axes=(0, 0, None, None)
|
||||
)
|
||||
|
||||
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||
|
||||
# the distance between each species' center and each genome in population
|
||||
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
|
||||
|
||||
def continue_execute_while(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
def deal_with_each_center_genome(carry):
|
||||
i, i2s, scn, scc = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
|
||||
|
||||
i2s, scn, scc = jax.lax.cond(
|
||||
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
|
||||
create_new_specie, # if not valid, create a new specie
|
||||
update_exist_specie, # if valid, update the specie
|
||||
(i, i2s, scn, scc)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc
|
||||
|
||||
def create_new_specie(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(i2s == I_INT)
|
||||
|
||||
# assign it to new specie
|
||||
i2s = i2s.at[idx].set(i)
|
||||
|
||||
# update center genomes
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
|
||||
return i2s, scn, scc
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
|
||||
# find new center
|
||||
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||
|
||||
# update new center
|
||||
i2s = i2s.at[idx].set(i)
|
||||
|
||||
# update center genomes
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
|
||||
return i2s, scn, scc
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
|
||||
close_enough_mask = o2p_distance < compatibility_threshold
|
||||
|
||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||
i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s)
|
||||
return i2s, scn, scc
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop(
|
||||
continue_execute_while,
|
||||
deal_with_each_center_genome,
|
||||
(0, idx2specie, spe_center_nodes, spe_center_cons)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie)
|
||||
|
||||
return idx2specie, spe_center_nodes, spe_center_cons
|
||||
|
||||
Reference in New Issue
Block a user