remove create_func....

This commit is contained in:
wls2002
2023-08-02 13:26:01 +08:00
parent 85318f98f3
commit 1499e062fe
34 changed files with 558 additions and 1022 deletions

View File

@@ -1,11 +1,9 @@
from typing import Type
import jax
from jax import numpy as jnp, vmap
from core import Gene, Genome
from core import Gene, Genome, State
from utils import rank_elements, fetch_first
from .distance import create_distance
from .distance import distance
from .species_info import SpeciesInfo
@@ -170,154 +168,149 @@ def create_crossover_pair(state, randkey, spawn_number, fitness):
return winner, loser, elite_mask
def create_speciate(gene_type: Type[Gene]):
distance = create_distance(gene_type)
def speciate(gene: Gene, state: State):
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
def speciate(state):
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population
# idx to specie key
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# idx to specie key
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size,), jnp.inf)
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
i, i2s, cgs, o2c = carry
# step 1: find new centers
def cond_func(carry):
i, i2s, cgs, o2c = carry
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
def body_func(carry):
i, i2s, cgs, o2c = carry
def body_func(carry):
i, i2s, cgs, o2c = carry
distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
distances = o2p_distance_func(state, cgs[i], state.pop_genomes)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cgs = cgs.set(i, state.pop_genomes[closest_idx])
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cgs = cgs.set(i, state.pop_genomes[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cgs, o2c
return i + 1, i2s, cgs, o2c
_, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
_, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
state = state.update(
idx2species=idx2species,
center_genomes=center_genomes,
)
state = state.update(
idx2species=idx2species,
center_genomes=center_genomes,
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cgs, sk, o2c, nsk)
)
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
return i + 1, i2s, cgs, sk, o2c, nsk
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def create_new_species(carry):
i, i2s, cgs, sk, o2c, nsk = carry
def body_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cgs, sk, o2c, nsk)
)
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
return i + 1, i2s, cgs, sk, o2c, nsk
# update center genomes
cgs = cgs.set(i, state.pop_genomes[idx])
def create_new_species(carry):
i, i2s, cgs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
def update_exist_specie(carry):
i, i2s, cgs, sk, o2c, nsk = carry
# update center genomes
cgs = cgs.set(i, state.pop_genomes[idx])
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# turn to next species
return i + 1, i2s, cgs, sk, o2c, nsk
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
# distance between such center genome and ppo genomes
def update_exist_specie(carry):
i, i2s, cgs, sk, o2c, nsk = carry
o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
close_enough_mask = o2p_distance < state.compatibility_threshold
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# turn to next species
return i + 1, i2s, cgs, sk, o2c, nsk
# update species info
i2s = jnp.where(mask, sk[i], i2s)
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
# distance between such center genome and ppo genomes
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes)
close_enough_mask = o2p_distance < state.compatibility_threshold
return i2s, o2c
# when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# update idx2species
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
state.next_species_key)
)
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
return i2s, o2c
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
# update idx2species
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key)
)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
return state.update(
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key
)
return speciate
return state.update(
species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key
)
def argmin_with_mask(arr, mask):