add License and pyproject.toml

This commit is contained in:
root
2024-07-11 23:56:06 +08:00
parent e2869c7562
commit 5fdf7b29bc
60 changed files with 71 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .base import BaseAlgorithm
from .neat import NEAT
from .hyperneat import HyperNEAT

View File

@@ -0,0 +1,30 @@
from tensorneat.common import State, StatefulBaseClass
class BaseAlgorithm(StatefulBaseClass):
def ask(self, state: State):
"""require the population to be evaluated"""
raise NotImplementedError
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
raise NotImplementedError
def transform(self, state, individual):
"""transform the genome into a neural network"""
raise NotImplementedError
def forward(self, state, transformed, inputs):
raise NotImplementedError
def show_details(self, state: State, fitness):
"""Visualize the running details of the algorithm"""
raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError
@property
def num_outputs(self):
raise NotImplementedError

View File

@@ -0,0 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate

View File

@@ -0,0 +1,125 @@
from typing import Callable
import jax
from jax import vmap, numpy as jnp
from .substrate import *
from tensorneat.common import State, Act, Agg
from tensorneat.algorithm import BaseAlgorithm, NEAT
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
class HyperNEAT(BaseAlgorithm):
def __init__(
self,
substrate: BaseSubstrate,
neat: NEAT,
weight_threshold: float = 0.3,
max_weight: float = 5.0,
aggregation: Callable = Agg.sum,
activation: Callable = Act.sigmoid,
activate_time: int = 10,
output_transform: Callable = Act.standard_sigmoid,
):
assert (
substrate.query_coors.shape[1] == neat.num_inputs
), "Query coors of Substrate should be equal to NEAT input size"
self.substrate = substrate
self.neat = neat
self.weight_threshold = weight_threshold
self.max_weight = max_weight
self.hyper_genome = RecurrentGenome(
num_inputs=substrate.num_inputs,
num_outputs=substrate.num_outputs,
max_nodes=substrate.nodes_cnt,
max_conns=substrate.conns_cnt,
node_gene=HyperNEATNode(aggregation, activation),
conn_gene=HyperNEATConn(),
activate_time=activate_time,
output_transform=output_transform,
)
self.pop_size = neat.pop_size
def setup(self, state=State()):
state = self.neat.setup(state)
state = self.substrate.setup(state)
return self.hyper_genome.setup(state)
def ask(self, state):
return self.neat.ask(state)
def tell(self, state, fitness):
state = self.neat.tell(state, fitness)
return state
def transform(self, state, individual):
transformed = self.neat.transform(state, individual)
query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
state, transformed, self.substrate.query_coors
)
# mute the connection with weight weight threshold
query_res = jnp.where(
(-self.weight_threshold < query_res) & (query_res < self.weight_threshold),
0.0,
query_res,
)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(
query_res > 0, query_res - self.weight_threshold, query_res
)
query_res = jnp.where(
query_res < 0, query_res + self.weight_threshold, query_res
)
query_res = query_res / (1 - self.weight_threshold) * self.max_weight
h_nodes, h_conns = self.substrate.make_nodes(
query_res
), self.substrate.make_conns(query_res)
return self.hyper_genome.transform(state, h_nodes, h_conns)
def forward(self, state, transformed, inputs):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
res = self.hyper_genome.forward(state, transformed, inputs_with_bias)
return res
@property
def num_inputs(self):
return self.substrate.num_inputs - 1 # remove bias
@property
def num_outputs(self):
return self.substrate.num_outputs
def show_details(self, state, fitness):
return self.neat.show_details(state, fitness)
class HyperNEATNode(BaseNode):
def __init__(
self,
aggregation=Agg.sum,
activation=Act.sigmoid,
):
super().__init__()
self.aggregation = aggregation
self.activation = activation
def forward(self, state, attrs, inputs, is_output_node=False):
return jax.lax.cond(
is_output_node,
lambda: self.aggregation(inputs), # output node does not need activation
lambda: self.activation(self.aggregation(inputs)),
)
class HyperNEATConn(BaseConn):
custom_attrs = ["weight"]
def forward(self, state, attrs, inputs):
weight = attrs[0]
return inputs * weight

View File

@@ -0,0 +1,3 @@
from .base import BaseSubstrate
from .default import DefaultSubstrate
from .full import FullSubstrate

View File

@@ -0,0 +1,30 @@
from tensorneat.common import StatefulBaseClass
class BaseSubstrate(StatefulBaseClass):
def make_nodes(self, query_res):
raise NotImplementedError
def make_conns(self, query_res):
raise NotImplementedError
@property
def query_coors(self):
raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError
@property
def num_outputs(self):
raise NotImplementedError
@property
def nodes_cnt(self):
raise NotImplementedError
@property
def conns_cnt(self):
raise NotImplementedError

View File

@@ -0,0 +1,40 @@
from jax import vmap, numpy as jnp
from .base import BaseSubstrate
from tensorneat.genome.utils import set_conn_attrs
class DefaultSubstrate(BaseSubstrate):
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
self.inputs = num_inputs
self.outputs = num_outputs
self.coors = jnp.array(coors)
self.nodes = jnp.array(nodes)
self.conns = jnp.array(conns)
def make_nodes(self, query_res):
return self.nodes
def make_conns(self, query_res):
# change weight of conns
return vmap(set_conn_attrs)(self.conns, query_res)
@property
def query_coors(self):
return self.coors
@property
def num_inputs(self):
return self.inputs
@property
def num_outputs(self):
return self.outputs
@property
def nodes_cnt(self):
return self.nodes.shape[0]
@property
def conns_cnt(self):
return self.conns.shape[0]

View File

@@ -0,0 +1,79 @@
import numpy as np
from .default import DefaultSubstrate
class FullSubstrate(DefaultSubstrate):
def __init__(
self,
input_coors=((-1, -1), (0, -1), (1, -1)),
hidden_coors=((-1, 0), (0, 0), (1, 0)),
output_coors=((0, 1),),
):
query_coors, nodes, conns = analysis_substrate(
input_coors, output_coors, hidden_coors
)
super().__init__(len(input_coors), len(output_coors), query_coors, nodes, conns)
def analysis_substrate(input_coors, output_coors, hidden_coors):
input_coors = np.array(input_coors)
output_coors = np.array(output_coors)
hidden_coors = np.array(hidden_coors)
cd = input_coors.shape[1] # coordinate dimensions
si = input_coors.shape[0] # input coordinate size
so = output_coors.shape[0] # output coordinate size
sh = hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(
input_idx, hidden_idx, input_coors, hidden_coors
)
query_coors[0 : si * sh, :] = aux_coors
correspond_keys[0 : si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(
hidden_idx, hidden_idx, hidden_coors, hidden_coors
)
query_coors[si * sh : si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh : si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(
hidden_idx, output_idx, hidden_coors, output_coors
)
query_coors[si * sh + sh * sh :, :] = aux_coors
correspond_keys[si * sh + sh * sh :, :] = aux_keys
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
conns = np.zeros(
(correspond_keys.shape[0], 3), dtype=np.float32
) # input_idx, output_idx, weight
conns[:, :2] = correspond_keys
return query_coors, nodes, conns
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -0,0 +1,2 @@
from .species import *
from .neat import NEAT

View File

@@ -0,0 +1,167 @@
from typing import Callable
import jax
from jax import vmap, numpy as jnp
import numpy as np
from .species import SpeciesController
from .. import BaseAlgorithm
from tensorneat.common import State
from tensorneat.genome import BaseGenome
class NEAT(BaseAlgorithm):
def __init__(
self,
genome: BaseGenome,
pop_size: int,
species_size: int = 10,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.1,
min_species_size: int = 1,
compatibility_threshold: float = 2.0,
species_fitness_func: Callable = jnp.max,
):
self.genome = genome
self.pop_size = pop_size
self.species_controller = SpeciesController(
pop_size,
species_size,
max_stagnation,
species_elitism,
spawn_number_change_rate,
genome_elitism,
survival_threshold,
min_species_size,
compatibility_threshold,
species_fitness_func,
)
def setup(self, state=State()):
# setup state
state = self.genome.setup(state)
k1, randkey = jax.random.split(state.randkey, 2)
# initialize the population
initialize_keys = jax.random.split(k1, self.pop_size)
pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
state, initialize_keys
)
state = state.register(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
generation=jnp.float32(0),
)
# initialize species state
state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])
return state.update(randkey=randkey)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
state = state.update(generation=state.generation + 1)
# tell fitness to species controller
state, winner, loser, elite_mask = self.species_controller.update_species(
state,
fitness,
)
# create next population
state = self._create_next_generation(state, winner, loser, elite_mask)
# speciate the next population
state = self.species_controller.speciate(state, self.genome.execute_distance)
return state
def transform(self, state, individual):
nodes, conns = individual
return self.genome.transform(state, nodes, conns)
def forward(self, state, transformed, inputs):
return self.genome.forward(state, transformed, inputs)
@property
def num_inputs(self):
return self.genome.num_inputs
@property
def num_outputs(self):
return self.genome.num_outputs
def _create_next_generation(self, state, winner, loser, elite_mask):
# find next node key for mutation
all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
)
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
# prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size)
mutate_randkeys = jax.random.split(k2, self.pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
)
def show_details(self, state, fitness):
member_count = jax.device_get(state.species.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
max_node_cnt, min_node_cnt, mean_node_cnt = (
max(nodes_cnt),
min(nodes_cnt),
np.mean(nodes_cnt),
)
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
max(conns_cnt),
min(conns_cnt),
np.mean(conns_cnt),
)
print(
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
)

View File

@@ -0,0 +1,537 @@
from typing import Callable
import jax
from jax import vmap, numpy as jnp
import numpy as np
from tensorneat.common import (
State,
StatefulBaseClass,
rank_elements,
argmin_with_mask,
fetch_first,
)
class SpeciesController(StatefulBaseClass):
def __init__(
self,
pop_size,
species_size,
max_stagnation,
species_elitism,
spawn_number_change_rate,
genome_elitism,
survival_threshold,
min_species_size,
compatibility_threshold,
species_fitness_func,
):
self.pop_size = pop_size
self.species_size = species_size
self.species_arange = np.arange(self.species_size)
self.max_stagnation = max_stagnation
self.species_elitism = species_elitism
self.spawn_number_change_rate = spawn_number_change_rate
self.genome_elitism = genome_elitism
self.survival_threshold = survival_threshold
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.species_fitness_func = species_fitness_func
def setup(self, state, first_nodes, first_conns):
# the unique index (primary key) for each species
species_keys = jnp.full((self.species_size,), jnp.nan)
# the best fitness of each species
best_fitness = jnp.full((self.species_size,), jnp.nan)
# the last 1 that the species improved
last_improved = jnp.full((self.species_size,), jnp.nan)
# the number of members of each species
member_count = jnp.full((self.species_size,), jnp.nan)
# the species index of each individual
idx2species = jnp.zeros(self.pop_size)
# nodes for each center genome of each species
center_nodes = jnp.full(
(self.species_size, *first_nodes.shape),
jnp.nan,
)
# connections for each center genome of each species
center_conns = jnp.full(
(self.species_size, *first_conns.shape),
jnp.nan,
)
species_keys = species_keys.at[0].set(0)
best_fitness = best_fitness.at[0].set(-jnp.inf)
last_improved = last_improved.at[0].set(0)
member_count = member_count.at[0].set(self.pop_size)
center_nodes = center_nodes.at[0].set(first_nodes)
center_conns = center_conns.at[0].set(first_conns)
species_state = State(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=jnp.float32(1), # 0 is reserved for the first species
)
return state.register(species=species_state)
def update_species(self, state, fitness):
species_state = state.species
# update the fitness of each species
species_fitness = self._update_species_fitness(species_state, fitness)
# stagnation species
species_state, species_fitness = self._stagnation(
species_state, species_fitness, state.generation
)
# sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1] # fitness from high to low
species_state = species_state.update(
species_keys=species_state.species_keys[sort_indices],
best_fitness=species_state.best_fitness[sort_indices],
last_improved=species_state.last_improved[sort_indices],
member_count=species_state.member_count[sort_indices],
center_nodes=species_state.center_nodes[sort_indices],
center_conns=species_state.center_conns[sort_indices],
)
# decide the number of members of each species by their fitness
spawn_number = self._cal_spawn_numbers(species_state)
k1, k2 = jax.random.split(state.randkey)
# crossover info
winner, loser, elite_mask = self._create_crossover_pair(
species_state, k1, spawn_number, fitness
)
return (
state.update(randkey=k2, species=species_state),
winner,
loser,
elite_mask,
)
def _update_species_fitness(self, species_state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
s_fitness = jnp.where(
species_state.idx2species == species_state.species_keys[idx],
fitness,
-jnp.inf,
)
val = self.species_fitness_func(s_fitness)
return val
return vmap(aux_func)(self.species_arange)
def _stagnation(self, species_state, species_fitness, generation):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
"""
def check_stagnation(idx):
# determine whether the species stagnation
# not better than the best fitness of the species
# for a long time
st = (species_fitness[idx] <= species_state.best_fitness[idx]) & (
generation - species_state.last_improved[idx] > self.max_stagnation
)
# update last_improved and best_fitness
# whether better than the best fitness of the species
li, bf = jax.lax.cond(
species_fitness[idx] > species_state.best_fitness[idx],
lambda: (generation, species_fitness[idx]), # update
lambda: (
species_state.last_improved[idx],
species_state.best_fitness[idx],
), # not update
)
return st, bf, li
spe_st, best_fitness, last_improved = vmap(check_stagnation)(
self.species_arange
)
# update species state
species_state = species_state.update(
best_fitness=best_fitness,
last_improved=last_improved,
)
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(
species_rank < self.species_elitism, False, spe_st
) # elitism never stagnation
# set stagnation species to nan
def update_func(idx):
return jax.lax.cond(
spe_st[idx],
lambda: (
jnp.nan, # species_key
jnp.nan, # best_fitness
jnp.nan, # last_improved
jnp.nan, # member_count
jnp.full_like(species_state.center_nodes[idx], jnp.nan),
jnp.full_like(species_state.center_conns[idx], jnp.nan),
-jnp.inf, # species_fitness
), # stagnation species
lambda: (
species_state.species_keys[idx],
species_state.best_fitness[idx],
species_state.last_improved[idx],
species_state.member_count[idx],
species_state.center_nodes[idx],
species_state.center_conns[idx],
species_fitness[idx],
), # not stagnation species
)
(
species_keys,
best_fitness,
last_improved,
member_count,
center_nodes,
center_conns,
species_fitness,
) = vmap(update_func)(self.species_arange)
return (
species_state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
center_nodes=center_nodes,
center_conns=center_conns,
),
species_fitness,
)
def _cal_spawn_numbers(self, species_state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
species_keys = species_state.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (
(valid_species_num + 1) * valid_species_num / 2
) # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(
is_species_valid, spawn_number_rate, 0
) # set invalid species to 0
target_spawn_number = jnp.floor(
spawn_number_rate * self.pop_size
) # calculate member
# Avoid too much variation of numbers for a species
previous_size = species_state.member_count
spawn_number = (
previous_size
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
)
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = self.pop_size - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
return spawn_number
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size)
def aux_func(key, idx):
# choose parents from the in the same species
# key -> randkey, idx -> the idx of current species
members = species_state.idx2species == species_state.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
survive_size = jnp.floor(self.survival_threshold * members_num).astype(
jnp.int32
)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(
key,
sorted_member_indices,
shape=(2, self.pop_size),
replace=True,
p=select_pro,
)
# elite
fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa)
ma = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, ma)
elite = jnp.where(p_idx < self.genome_elitism, True, False)
return fa, ma, elite
# choose parents to crossover in each species
# fas, mas, elites: (self.species_size, self.pop_size)
# fas -> father indices, mas -> mother indices, elites -> whether elite or not
fas, mas, elites = vmap(aux_func)(
jax.random.split(randkey, self.species_size), s_idx
)
# merge choosen parents from each species into one array
# winner, loser, elite_mask: (self.pop_size)
# winner -> winner indices, loser -> loser indices, elite_mask -> whether elite or not
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return (
fas[loc, idx_in_species],
mas[loc, idx_in_species],
elites[loc, idx_in_species],
)
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def speciate(self, state, genome_distance_func: Callable):
# prepare distance functions
o2p_distance_func = vmap(
genome_distance_func, in_axes=(None, None, None, 0, 0)
) # one to population
# idx to specie key
idx2species = jnp.full(
(self.pop_size,), jnp.nan
) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, o2c_distances
i, i2s, cns, ccs, o2c = carry
return (i < self.species_size) & (
~jnp.isnan(state.species.species_keys[i])
) # current species is existing
def body_func(carry):
i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species.species_keys[i])
cns = cns.at[i].set(state.pop_nodes[closest_idx])
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cns, ccs, o2c
_, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
cond_func,
body_func,
(
0,
idx2species,
state.species.center_nodes,
state.species.center_conns,
o2c_distances,
),
)
state = state.update(
species=state.species.update(
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
),
)
# part 2: assign members to each species
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key
i, i2s, cns, ccs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < self.species_size
return not_reach_species_upper_bounds & (
current_species_existed | not_all_assigned
)
def body_func(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
_, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cns, ccs, sk, o2c, nsk),
)
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk) # nsk -> next species key
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cns = cns.at[i].set(state.pop_nodes[idx])
ccs = ccs.at[i].set(state.pop_conns[idx])
# find the members for the new species
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
# turn to next species
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
)
close_enough_mask = o2p_distance < self.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
mask = close_enough_mask & catchable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
(
_,
idx2species,
center_nodes,
center_conns,
species_keys,
_,
next_species_key,
) = jax.lax.while_loop(
cond_func,
body_func,
(
0,
state.species.idx2species,
center_nodes,
center_conns,
state.species.species_keys,
o2c_distances,
state.species.next_species_key,
),
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
last_improved = jnp.where(
new_created_mask, state.generation, state.species.last_improved
)
# update members count
def count_members(idx):
return jax.lax.cond(
jnp.isnan(species_keys[idx]), # if the species is not existing
lambda: jnp.nan, # nan
lambda: jnp.sum(
idx2species == species_keys[idx], dtype=jnp.float32
), # count members
)
member_count = vmap(count_members)(self.species_arange)
species_state = state.species.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=next_species_key,
)
return state.update(
species=species_state,
)