Files
tensorneat-mend/neat/operations.py
wls2002 ba369db0b2 Perfect!
Next is to connect with Evox!
2023-06-25 02:57:45 +08:00

172 lines
6.5 KiB
Python

"""
Contains operations on the population: creating the next generation and speciating the population.
"""
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
                                         center_nodes, center_cons, species_keys, new_species_key_start,
                                         jit_config):
    """
    Produce the next generation and immediately speciate it.

    The first eight-argument group (rand_key .. new_node_keys, plus jit_config) is forwarded to
    create_next_generation; the resulting population, together with the species state
    (center_nodes, center_cons, species_keys, new_species_key_start) and jit_config, is forwarded
    to speciate.

    returns:
        (pop_nodes, pop_cons, idx2specie, center_nodes, center_cons, species_keys), where the first
        two entries are the new population and the remaining four are exactly what speciate returns.
    """
    # step 1: breed the new population
    next_nodes, next_cons = create_next_generation(
        rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config
    )

    # step 2: speciate the freshly created population
    speciation = speciate(
        next_nodes, next_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config
    )

    return (next_nodes, next_cons, *speciation)
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
    """
    Build the next population by batched crossover followed by batched mutation.

    args:
        rand_key: PRNG key; split into one crossover key and one mutation key per genome.
        pop_nodes: node arrays of the current population, leading axis is the population.
        pop_cons: connection arrays of the current population, leading axis is the population.
        winner: indices into the population selecting the first parent of every child.
        loser: indices into the population selecting the second parent of every child.
        elite_mask: per-genome boolean mask; True keeps the un-mutated crossover result.
        new_node_keys: per-genome values forwarded to mutate (mapped over the leading axis).
        jit_config: configuration forwarded unchanged (not mapped) to mutate.
    returns:
        (pop_nodes, pop_cons) of the new generation.
    """
    population = pop_nodes.shape[0]

    # one independent key per genome for each of the two stochastic stages
    crossover_seed, mutation_seed = jax.random.split(rand_key, 2)
    crossover_keys = jax.random.split(crossover_seed, population)
    mutation_keys = jax.random.split(mutation_seed, population)

    # crossover: gather both parents for every child and recombine them in one batch
    child_nodes, child_cons = vmap(crossover)(
        crossover_keys,
        pop_nodes[winner], pop_cons[winner],
        pop_nodes[loser], pop_cons[loser],
    )

    # mutation: everything is batched except the shared jit_config
    batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, None))
    mutated_nodes, mutated_cons = batched_mutate(mutation_keys, child_nodes, child_cons, new_node_keys, jit_config)

    # elites skip mutation; the mask is broadcast over the two trailing axes
    keep_unmutated = elite_mask[:, None, None]
    next_nodes = jnp.where(keep_unmutated, child_nodes, mutated_nodes)
    next_cons = jnp.where(keep_unmutated, child_cons, mutated_cons)
    return next_nodes, next_cons
@jit
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
    """
    Assign every genome of the population to a species and refresh the species centers.

    args:
        pop_nodes: (pop_size, N, 5)
        pop_cons: (pop_size, C, 4)
        center_nodes: (species_size, N, 5)
        center_cons: (species_size, C, 4)
        species_keys: one key per species slot; I_INT marks a slot that holds no species.
        new_species_key_start: key given to the first species created here; each further
            new species gets the previous key plus one.
        jit_config: configuration forwarded to distance; 'compatibility_threshold' is read here.
    returns:
        (idx2specie, center_nodes, center_cons, species_keys): the species key of every genome,
        followed by the updated centers and species keys.
    """
    n_genomes = pop_nodes.shape[0]
    n_slots = center_nodes.shape[0]

    # distance from one genome to every genome of the population
    to_population = vmap(distance, in_axes=(None, None, 0, 0, None))
    # distance from every species center to every genome of the population
    centers_to_population = vmap(to_population, in_axes=(0, 0, None, None, None))

    # species key per genome; I_INT means "not assigned to any species yet"
    idx2specie = jnp.full((n_genomes,), I_INT, dtype=jnp.int32)

    # ---- part 1: choose a new center for every existing species ----
    center_distances = centers_to_population(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)

    def pick_center(slot, state):
        assignment, nodes, cons = state
        # closest genome that has not been taken by an earlier species
        nearest = argmin_with_mask(center_distances[slot], mask=assignment == I_INT)
        # For an empty slot both indices are replaced by I_INT so that the
        # .at[...].set(...) updates below do not assign anything.
        slot_in_use = species_keys[slot] != I_INT
        nearest = jnp.where(slot_in_use, nearest, I_INT)
        slot = jnp.where(slot_in_use, slot, I_INT)
        assignment = assignment.at[nearest].set(species_keys[slot])
        nodes = nodes.at[slot].set(pop_nodes[nearest])
        cons = cons.at[slot].set(pop_cons[nearest])
        return assignment, nodes, cons

    idx2specie, center_nodes, center_cons = jax.lax.fori_loop(
        0, n_slots, pick_center, (idx2specie, center_nodes, center_cons)
    )

    # ---- part 2: assign the remaining genomes, slot by slot ----
    def absorb_close_genomes(slot, assignment, nodes, cons, keys):
        # distance between this slot's center genome and every genome of the population
        dist = to_population(nodes[slot], cons[slot], pop_nodes, pop_cons, jit_config)
        within_threshold = dist < jit_config['compatibility_threshold']
        # only genomes that are still unassigned may be claimed
        claimable = within_threshold & (assignment == I_INT)
        return jnp.where(claimable, keys[slot], assignment)

    def found_species(state):
        slot, assignment, nodes, cons, keys, next_key = state
        # the first genome that is still unassigned becomes the center of the new species
        founder = fetch_first(assignment == I_INT)
        keys = keys.at[slot].set(next_key)
        assignment = assignment.at[founder].set(next_key)
        nodes = nodes.at[slot].set(pop_nodes[founder])
        cons = cons.at[slot].set(pop_cons[founder])
        assignment = absorb_close_genomes(slot, assignment, nodes, cons, keys)
        # the following new species will use the next key
        return assignment, nodes, cons, keys, next_key + 1

    def extend_species(state):
        slot, assignment, nodes, cons, keys, next_key = state
        assignment = absorb_close_genomes(slot, assignment, nodes, cons, keys)
        return assignment, nodes, cons, keys, next_key

    def keep_going(state):
        slot, assignment = state[0], state[1]
        some_unassigned = jnp.any(assignment == I_INT)
        slots_left = slot < n_slots
        return some_unassigned & slots_left

    def step(state):
        slot, keys = state[0], state[4]
        updated = jax.lax.cond(
            keys[slot] == I_INT,  # an empty slot gets a brand new species
            found_species,
            extend_species,  # an occupied slot only collects more members
            state
        )
        return (slot + 1, *updated)

    _, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
        keep_going,
        step,
        (0, idx2specie, center_nodes, center_cons, species_keys, new_species_key_start)
    )

    # Genomes that are still unassigned are put into the last species slot. According to the
    # original author this seems to happen only when the species upper bound has been reached.
    leftover = idx2specie == I_INT
    idx2specie = jnp.where(leftover, species_keys[-1], idx2specie)

    return idx2specie, center_nodes, center_cons, species_keys
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
    """
    Return the index of the smallest element of ``arr`` among the positions selected by ``mask``.

    args:
        arr: values to search (floating or integer dtype).
        mask: truthy entries mark the positions that may be chosen.
    returns:
        Scalar integer array holding the (flattened) index of the minimum over the selected
        positions; the first one on ties. When no position is selected the result is 0.
    """
    arr = jnp.asarray(arr)

    # The dtype is static under jit, so this Python branch is resolved at trace time.
    if jnp.issubdtype(arr.dtype, jnp.integer):
        # Filling with jnp.inf would promote the integers to floating point, where distinct
        # large values can round to the same number and yield the wrong index. Stay in the
        # integer dtype and compare exactly instead.
        fill = jnp.array(jnp.iinfo(arr.dtype).max, dtype=arr.dtype)
        smallest = jnp.min(jnp.where(mask, arr, fill))
        # First selected position holding the minimum; argmax over an all-False array is 0,
        # which matches the floating-point path when nothing is selected.
        return jnp.argmax(jnp.logical_and(mask, arr == smallest))

    # Floating-point path: masked-out positions are replaced by +inf so they never win.
    masked_arr = jnp.where(mask, arr, jnp.inf)
    return jnp.argmin(masked_arr)