remove create_func....
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from core import Gene, Genome
|
||||
from core import Gene, Genome, State
|
||||
from utils import rank_elements, fetch_first
|
||||
from .distance import create_distance
|
||||
from .distance import distance
|
||||
from .species_info import SpeciesInfo
|
||||
|
||||
|
||||
@@ -170,154 +168,149 @@ def create_crossover_pair(state, randkey, spawn_number, fitness):
|
||||
return winner, loser, elite_mask
|
||||
|
||||
|
||||
def create_speciate(gene_type: Type[Gene]):
|
||||
distance = create_distance(gene_type)
|
||||
def speciate(gene: Gene, state: State):
|
||||
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
|
||||
|
||||
def speciate(state):
|
||||
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population
|
||||
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population
|
||||
# idx to specie key
|
||||
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||
|
||||
# idx to specie key
|
||||
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((pop_size,), jnp.inf)
|
||||
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((pop_size,), jnp.inf)
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
|
||||
|
||||
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
|
||||
|
||||
distances = o2p_distance_func(state, cgs[i], state.pop_genomes)
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
|
||||
cgs = cgs.set(i, state.pop_genomes[closest_idx])
|
||||
|
||||
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
|
||||
cgs = cgs.set(i, state.pop_genomes[closest_idx])
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
return i + 1, i2s, cgs, o2c
|
||||
|
||||
return i + 1, i2s, cgs, o2c
|
||||
_, idx2species, center_genomes, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
|
||||
|
||||
_, idx2species, center_genomes, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
|
||||
state = state.update(
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
)
|
||||
|
||||
state = state.update(
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
current_species_existed = ~jnp.isnan(sk[i])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(sk[i]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cgs, sk, o2c, nsk)
|
||||
)
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
|
||||
current_species_existed = ~jnp.isnan(sk[i])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||
def create_new_species(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(sk[i]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cgs, sk, o2c, nsk)
|
||||
)
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, member_count]
|
||||
sk = sk.at[i].set(nsk)
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
# update center genomes
|
||||
cgs = cgs.set(i, state.pop_genomes[idx])
|
||||
|
||||
def create_new_species(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
# when a new species is created, it needs to be updated, thus do not change i
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
|
||||
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, member_count]
|
||||
sk = sk.at[i].set(nsk)
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
# update center genomes
|
||||
cgs = cgs.set(i, state.pop_genomes[idx])
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
# turn to next species
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
|
||||
# when a new species is created, it needs to be updated, thus do not change i
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
|
||||
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
|
||||
close_enough_mask = o2p_distance < state.compatibility_threshold
|
||||
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||
# jax.debug.print("{}", o2p_distance)
|
||||
mask = close_enough_mask & cacheable_mask
|
||||
|
||||
# turn to next species
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
# update species info
|
||||
i2s = jnp.where(mask, sk[i], i2s)
|
||||
|
||||
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
# update distance between centers
|
||||
o2c = jnp.where(mask, o2p_distance, o2c)
|
||||
|
||||
o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes)
|
||||
close_enough_mask = o2p_distance < state.compatibility_threshold
|
||||
return i2s, o2c
|
||||
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||
# jax.debug.print("{}", o2p_distance)
|
||||
mask = close_enough_mask & cacheable_mask
|
||||
# update idx2species
|
||||
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
|
||||
state.next_species_key)
|
||||
)
|
||||
|
||||
# update species info
|
||||
i2s = jnp.where(mask, sk[i], i2s)
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition can only happen when the number of species is reached species upper bounds
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# update distance between centers
|
||||
o2c = jnp.where(mask, o2p_distance, o2c)
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
|
||||
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
|
||||
|
||||
return i2s, o2c
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
key = species_keys[idx]
|
||||
count = jnp.sum(idx2species == key, dtype=jnp.float32)
|
||||
count = jnp.where(jnp.isnan(key), jnp.nan, count)
|
||||
|
||||
# update idx2species
|
||||
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key)
|
||||
)
|
||||
return count
|
||||
|
||||
member_count = vmap(count_members)(jnp.arange(species_size))
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition can only happen when the number of species is reached species upper bounds
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
|
||||
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
|
||||
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
key = species_keys[idx]
|
||||
count = jnp.sum(idx2species == key, dtype=jnp.float32)
|
||||
count = jnp.where(jnp.isnan(key), jnp.nan, count)
|
||||
|
||||
return count
|
||||
|
||||
member_count = vmap(count_members)(jnp.arange(species_size))
|
||||
|
||||
return state.update(
|
||||
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
next_species_key=next_species_key
|
||||
)
|
||||
|
||||
return speciate
|
||||
return state.update(
|
||||
species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
next_species_key=next_species_key
|
||||
)
|
||||
|
||||
|
||||
def argmin_with_mask(arr, mask):
|
||||
|
||||
Reference in New Issue
Block a user