remove create_func....
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from .operations import update_species, create_speciate
|
||||
from .species_info import SpeciesInfo
|
||||
from .operations import update_species, speciate
|
||||
|
||||
@@ -1,73 +1,71 @@
|
||||
from typing import Type
|
||||
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
from core import Gene
|
||||
|
||||
|
||||
def create_distance(gene_type: Type[Gene]):
|
||||
def node_distance(state, nodes1: Array, nodes2: Array):
|
||||
"""
|
||||
Calculate the distance between nodes of two genomes.
|
||||
"""
|
||||
# statistics nodes count of two genomes
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
def distance(gene: Gene, state, genome1, genome2):
|
||||
return node_distance(gene, state, genome1.nodes, genome2.nodes) + \
|
||||
connection_distance(gene, state, genome1.conns, genome2.conns)
|
||||
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
def node_distance(gene: Gene, state, nodes1: Array, nodes2: Array):
|
||||
"""
|
||||
Calculate the distance between nodes of two genomes.
|
||||
"""
|
||||
# statistics nodes count of two genomes
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
hnd = vmap(gene_type.distance_node, in_axes=(None, 0, 0))(state, fr, sr)
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
|
||||
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
|
||||
# calculate the distance of homologous nodes
|
||||
hnd = vmap(gene.distance_node, in_axes=(None, 0, 0))(state, fr, sr)
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
def connection_distance(state, cons1: Array, cons2: Array):
|
||||
"""
|
||||
Calculate the distance between connections of two genomes.
|
||||
Similar process as node_distance.
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
|
||||
|
||||
cons = jnp.concatenate((cons1, cons2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
hcd = vmap(gene_type.distance_conn, in_axes=(None, 0, 0))(state, fr, sr)
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
def connection_distance(gene: Gene, state, cons1: Array, cons2: Array):
|
||||
"""
|
||||
Calculate the distance between connections of two genomes.
|
||||
Similar process as node_distance.
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
|
||||
cons = jnp.concatenate((cons1, cons2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
def distance(state, genome1, genome2):
|
||||
return node_distance(state, genome1.nodes, genome2.nodes) + connection_distance(state, genome1.conns, genome2.conns)
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
hcd = vmap(gene.distance_conn, in_axes=(None, 0, 0))(state, fr, sr)
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
return distance
|
||||
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from core import Gene, Genome
|
||||
from core import Gene, Genome, State
|
||||
from utils import rank_elements, fetch_first
|
||||
from .distance import create_distance
|
||||
from .distance import distance
|
||||
from .species_info import SpeciesInfo
|
||||
|
||||
|
||||
@@ -170,154 +168,149 @@ def create_crossover_pair(state, randkey, spawn_number, fitness):
|
||||
return winner, loser, elite_mask
|
||||
|
||||
|
||||
def create_speciate(gene_type: Type[Gene]):
|
||||
distance = create_distance(gene_type)
|
||||
def speciate(gene: Gene, state: State):
|
||||
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
|
||||
|
||||
def speciate(state):
|
||||
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population
|
||||
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population
|
||||
# idx to specie key
|
||||
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||
|
||||
# idx to specie key
|
||||
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((pop_size,), jnp.inf)
|
||||
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((pop_size,), jnp.inf)
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
|
||||
|
||||
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, o2c = carry
|
||||
distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
|
||||
|
||||
distances = o2p_distance_func(state, cgs[i], state.pop_genomes)
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
|
||||
cgs = cgs.set(i, state.pop_genomes[closest_idx])
|
||||
|
||||
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
|
||||
cgs = cgs.set(i, state.pop_genomes[closest_idx])
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
return i + 1, i2s, cgs, o2c
|
||||
|
||||
return i + 1, i2s, cgs, o2c
|
||||
_, idx2species, center_genomes, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
|
||||
|
||||
_, idx2species, center_genomes, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
|
||||
state = state.update(
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
)
|
||||
|
||||
state = state.update(
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
current_species_existed = ~jnp.isnan(sk[i])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(sk[i]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cgs, sk, o2c, nsk)
|
||||
)
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
|
||||
current_species_existed = ~jnp.isnan(sk[i])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||
def create_new_species(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(sk[i]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cgs, sk, o2c, nsk)
|
||||
)
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, member_count]
|
||||
sk = sk.at[i].set(nsk)
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
# update center genomes
|
||||
cgs = cgs.set(i, state.pop_genomes[idx])
|
||||
|
||||
def create_new_species(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
# when a new species is created, it needs to be updated, thus do not change i
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
|
||||
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, member_count]
|
||||
sk = sk.at[i].set(nsk)
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
|
||||
# update center genomes
|
||||
cgs = cgs.set(i, state.pop_genomes[idx])
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
# turn to next species
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
|
||||
# when a new species is created, it needs to be updated, thus do not change i
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
|
||||
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cgs, sk, o2c, nsk = carry
|
||||
o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
|
||||
close_enough_mask = o2p_distance < state.compatibility_threshold
|
||||
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||
# jax.debug.print("{}", o2p_distance)
|
||||
mask = close_enough_mask & cacheable_mask
|
||||
|
||||
# turn to next species
|
||||
return i + 1, i2s, cgs, sk, o2c, nsk
|
||||
# update species info
|
||||
i2s = jnp.where(mask, sk[i], i2s)
|
||||
|
||||
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
# update distance between centers
|
||||
o2c = jnp.where(mask, o2p_distance, o2c)
|
||||
|
||||
o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes)
|
||||
close_enough_mask = o2p_distance < state.compatibility_threshold
|
||||
return i2s, o2c
|
||||
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||
# jax.debug.print("{}", o2p_distance)
|
||||
mask = close_enough_mask & cacheable_mask
|
||||
# update idx2species
|
||||
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
|
||||
state.next_species_key)
|
||||
)
|
||||
|
||||
# update species info
|
||||
i2s = jnp.where(mask, sk[i], i2s)
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition can only happen when the number of species is reached species upper bounds
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# update distance between centers
|
||||
o2c = jnp.where(mask, o2p_distance, o2c)
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
|
||||
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
|
||||
|
||||
return i2s, o2c
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
key = species_keys[idx]
|
||||
count = jnp.sum(idx2species == key, dtype=jnp.float32)
|
||||
count = jnp.where(jnp.isnan(key), jnp.nan, count)
|
||||
|
||||
# update idx2species
|
||||
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key)
|
||||
)
|
||||
return count
|
||||
|
||||
member_count = vmap(count_members)(jnp.arange(species_size))
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition can only happen when the number of species is reached species upper bounds
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
|
||||
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
|
||||
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
key = species_keys[idx]
|
||||
count = jnp.sum(idx2species == key, dtype=jnp.float32)
|
||||
count = jnp.where(jnp.isnan(key), jnp.nan, count)
|
||||
|
||||
return count
|
||||
|
||||
member_count = vmap(count_members)(jnp.arange(species_size))
|
||||
|
||||
return state.update(
|
||||
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
next_species_key=next_species_key
|
||||
)
|
||||
|
||||
return speciate
|
||||
return state.update(
|
||||
species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
next_species_key=next_species_key
|
||||
)
|
||||
|
||||
|
||||
def argmin_with_mask(arr, mask):
|
||||
|
||||
@@ -2,6 +2,7 @@ from jax.tree_util import register_pytree_node_class
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@register_pytree_node_class
|
||||
class SpeciesInfo:
|
||||
|
||||
@@ -44,7 +45,6 @@ class SpeciesInfo:
|
||||
def size(self):
|
||||
return self.species_keys.shape[0]
|
||||
|
||||
|
||||
def tree_flatten(self):
|
||||
children = self.species_keys, self.best_fitness, self.last_improved, self.member_count
|
||||
aux_data = None
|
||||
|
||||
Reference in New Issue
Block a user