add License and pyproject.toml
This commit is contained in:
3
src/tensorneat/algorithm/__init__.py
Normal file
3
src/tensorneat/algorithm/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAlgorithm
|
||||
from .neat import NEAT
|
||||
from .hyperneat import HyperNEAT
|
||||
30
src/tensorneat/algorithm/base.py
Normal file
30
src/tensorneat/algorithm/base.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class BaseAlgorithm(StatefulBaseClass):
|
||||
def ask(self, state: State):
|
||||
"""require the population to be evaluated"""
|
||||
raise NotImplementedError
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
"""update the state of the algorithm"""
|
||||
raise NotImplementedError
|
||||
|
||||
def transform(self, state, individual):
|
||||
"""transform the genome into a neural network"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def show_details(self, state: State, fitness):
|
||||
"""Visualize the running details of the algorithm"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
raise NotImplementedError
|
||||
2
src/tensorneat/algorithm/hyperneat/__init__.py
Normal file
2
src/tensorneat/algorithm/hyperneat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .hyperneat import HyperNEAT
|
||||
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate
|
||||
125
src/tensorneat/algorithm/hyperneat/hyperneat.py
Normal file
125
src/tensorneat/algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .substrate import *
|
||||
from tensorneat.common import State, Act, Agg
|
||||
from tensorneat.algorithm import BaseAlgorithm, NEAT
|
||||
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
|
||||
|
||||
|
||||
class HyperNEAT(BaseAlgorithm):
|
||||
def __init__(
|
||||
self,
|
||||
substrate: BaseSubstrate,
|
||||
neat: NEAT,
|
||||
weight_threshold: float = 0.3,
|
||||
max_weight: float = 5.0,
|
||||
aggregation: Callable = Agg.sum,
|
||||
activation: Callable = Act.sigmoid,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.standard_sigmoid,
|
||||
):
|
||||
assert (
|
||||
substrate.query_coors.shape[1] == neat.num_inputs
|
||||
), "Query coors of Substrate should be equal to NEAT input size"
|
||||
|
||||
self.substrate = substrate
|
||||
self.neat = neat
|
||||
self.weight_threshold = weight_threshold
|
||||
self.max_weight = max_weight
|
||||
self.hyper_genome = RecurrentGenome(
|
||||
num_inputs=substrate.num_inputs,
|
||||
num_outputs=substrate.num_outputs,
|
||||
max_nodes=substrate.nodes_cnt,
|
||||
max_conns=substrate.conns_cnt,
|
||||
node_gene=HyperNEATNode(aggregation, activation),
|
||||
conn_gene=HyperNEATConn(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform,
|
||||
)
|
||||
self.pop_size = neat.pop_size
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.neat.setup(state)
|
||||
state = self.substrate.setup(state)
|
||||
return self.hyper_genome.setup(state)
|
||||
|
||||
def ask(self, state):
|
||||
return self.neat.ask(state)
|
||||
|
||||
def tell(self, state, fitness):
|
||||
state = self.neat.tell(state, fitness)
|
||||
return state
|
||||
|
||||
def transform(self, state, individual):
|
||||
transformed = self.neat.transform(state, individual)
|
||||
query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
|
||||
state, transformed, self.substrate.query_coors
|
||||
)
|
||||
# mute the connection with weight weight threshold
|
||||
query_res = jnp.where(
|
||||
(-self.weight_threshold < query_res) & (query_res < self.weight_threshold),
|
||||
0.0,
|
||||
query_res,
|
||||
)
|
||||
|
||||
# make query res in range [-max_weight, max_weight]
|
||||
query_res = jnp.where(
|
||||
query_res > 0, query_res - self.weight_threshold, query_res
|
||||
)
|
||||
query_res = jnp.where(
|
||||
query_res < 0, query_res + self.weight_threshold, query_res
|
||||
)
|
||||
query_res = query_res / (1 - self.weight_threshold) * self.max_weight
|
||||
|
||||
h_nodes, h_conns = self.substrate.make_nodes(
|
||||
query_res
|
||||
), self.substrate.make_conns(query_res)
|
||||
|
||||
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
# add bias
|
||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||
|
||||
res = self.hyper_genome.forward(state, transformed, inputs_with_bias)
|
||||
return res
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.substrate.num_inputs - 1 # remove bias
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.substrate.num_outputs
|
||||
|
||||
def show_details(self, state, fitness):
|
||||
return self.neat.show_details(state, fitness)
|
||||
|
||||
|
||||
class HyperNEATNode(BaseNode):
|
||||
def __init__(
|
||||
self,
|
||||
aggregation=Agg.sum,
|
||||
activation=Act.sigmoid,
|
||||
):
|
||||
super().__init__()
|
||||
self.aggregation = aggregation
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: self.aggregation(inputs), # output node does not need activation
|
||||
lambda: self.activation(self.aggregation(inputs)),
|
||||
)
|
||||
|
||||
|
||||
class HyperNEATConn(BaseConn):
|
||||
custom_attrs = ["weight"]
|
||||
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
3
src/tensorneat/algorithm/hyperneat/substrate/__init__.py
Normal file
3
src/tensorneat/algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseSubstrate
|
||||
from .default import DefaultSubstrate
|
||||
from .full import FullSubstrate
|
||||
30
src/tensorneat/algorithm/hyperneat/substrate/base.py
Normal file
30
src/tensorneat/algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from tensorneat.common import StatefulBaseClass
|
||||
|
||||
|
||||
class BaseSubstrate(StatefulBaseClass):
|
||||
|
||||
def make_nodes(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_conns(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def nodes_cnt(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def conns_cnt(self):
|
||||
raise NotImplementedError
|
||||
40
src/tensorneat/algorithm/hyperneat/substrate/default.py
Normal file
40
src/tensorneat/algorithm/hyperneat/substrate/default.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseSubstrate
|
||||
from tensorneat.genome.utils import set_conn_attrs
|
||||
|
||||
|
||||
class DefaultSubstrate(BaseSubstrate):
|
||||
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
|
||||
self.inputs = num_inputs
|
||||
self.outputs = num_outputs
|
||||
self.coors = jnp.array(coors)
|
||||
self.nodes = jnp.array(nodes)
|
||||
self.conns = jnp.array(conns)
|
||||
|
||||
def make_nodes(self, query_res):
|
||||
return self.nodes
|
||||
|
||||
def make_conns(self, query_res):
|
||||
# change weight of conns
|
||||
return vmap(set_conn_attrs)(self.conns, query_res)
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
return self.coors
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.inputs
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.outputs
|
||||
|
||||
@property
|
||||
def nodes_cnt(self):
|
||||
return self.nodes.shape[0]
|
||||
|
||||
@property
|
||||
def conns_cnt(self):
|
||||
return self.conns.shape[0]
|
||||
79
src/tensorneat/algorithm/hyperneat/substrate/full.py
Normal file
79
src/tensorneat/algorithm/hyperneat/substrate/full.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import numpy as np
|
||||
from .default import DefaultSubstrate
|
||||
|
||||
|
||||
class FullSubstrate(DefaultSubstrate):
|
||||
def __init__(
|
||||
self,
|
||||
input_coors=((-1, -1), (0, -1), (1, -1)),
|
||||
hidden_coors=((-1, 0), (0, 0), (1, 0)),
|
||||
output_coors=((0, 1),),
|
||||
):
|
||||
query_coors, nodes, conns = analysis_substrate(
|
||||
input_coors, output_coors, hidden_coors
|
||||
)
|
||||
super().__init__(len(input_coors), len(output_coors), query_coors, nodes, conns)
|
||||
|
||||
|
||||
def analysis_substrate(input_coors, output_coors, hidden_coors):
|
||||
input_coors = np.array(input_coors)
|
||||
output_coors = np.array(output_coors)
|
||||
hidden_coors = np.array(hidden_coors)
|
||||
|
||||
cd = input_coors.shape[1] # coordinate dimensions
|
||||
si = input_coors.shape[0] # input coordinate size
|
||||
so = output_coors.shape[0] # output coordinate size
|
||||
sh = hidden_coors.shape[0] # hidden coordinate size
|
||||
|
||||
input_idx = np.arange(si)
|
||||
output_idx = np.arange(si, si + so)
|
||||
hidden_idx = np.arange(si + so, si + so + sh)
|
||||
|
||||
total_conns = si * sh + sh * sh + sh * so
|
||||
query_coors = np.zeros((total_conns, cd * 2))
|
||||
correspond_keys = np.zeros((total_conns, 2))
|
||||
|
||||
# connect input to hidden
|
||||
aux_coors, aux_keys = cartesian_product(
|
||||
input_idx, hidden_idx, input_coors, hidden_coors
|
||||
)
|
||||
query_coors[0 : si * sh, :] = aux_coors
|
||||
correspond_keys[0 : si * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to hidden
|
||||
aux_coors, aux_keys = cartesian_product(
|
||||
hidden_idx, hidden_idx, hidden_coors, hidden_coors
|
||||
)
|
||||
query_coors[si * sh : si * sh + sh * sh, :] = aux_coors
|
||||
correspond_keys[si * sh : si * sh + sh * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to output
|
||||
aux_coors, aux_keys = cartesian_product(
|
||||
hidden_idx, output_idx, hidden_coors, output_coors
|
||||
)
|
||||
query_coors[si * sh + sh * sh :, :] = aux_coors
|
||||
correspond_keys[si * sh + sh * sh :, :] = aux_keys
|
||||
|
||||
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
|
||||
conns = np.zeros(
|
||||
(correspond_keys.shape[0], 3), dtype=np.float32
|
||||
) # input_idx, output_idx, weight
|
||||
conns[:, :2] = correspond_keys
|
||||
|
||||
return query_coors, nodes, conns
|
||||
|
||||
|
||||
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||
len1 = keys1.shape[0]
|
||||
len2 = keys2.shape[0]
|
||||
|
||||
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||
repeated_keys1 = np.repeat(keys1, len2)
|
||||
|
||||
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||
tiled_keys2 = np.tile(keys2, len1)
|
||||
|
||||
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||
|
||||
return new_coors, correspond_keys
|
||||
2
src/tensorneat/algorithm/neat/__init__.py
Normal file
2
src/tensorneat/algorithm/neat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .species import *
|
||||
from .neat import NEAT
|
||||
167
src/tensorneat/algorithm/neat/neat.py
Normal file
167
src/tensorneat/algorithm/neat/neat.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .. import BaseAlgorithm
|
||||
from tensorneat.common import State
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
class NEAT(BaseAlgorithm):
|
||||
def __init__(
|
||||
self,
|
||||
genome: BaseGenome,
|
||||
pop_size: int,
|
||||
species_size: int = 10,
|
||||
max_stagnation: int = 15,
|
||||
species_elitism: int = 2,
|
||||
spawn_number_change_rate: float = 0.5,
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.1,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 2.0,
|
||||
species_fitness_func: Callable = jnp.max,
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_controller = SpeciesController(
|
||||
pop_size,
|
||||
species_size,
|
||||
max_stagnation,
|
||||
species_elitism,
|
||||
spawn_number_change_rate,
|
||||
genome_elitism,
|
||||
survival_threshold,
|
||||
min_species_size,
|
||||
compatibility_threshold,
|
||||
species_fitness_func,
|
||||
)
|
||||
|
||||
def setup(self, state=State()):
|
||||
# setup state
|
||||
state = self.genome.setup(state)
|
||||
|
||||
k1, randkey = jax.random.split(state.randkey, 2)
|
||||
|
||||
# initialize the population
|
||||
initialize_keys = jax.random.split(k1, self.pop_size)
|
||||
pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
|
||||
state, initialize_keys
|
||||
)
|
||||
|
||||
state = state.register(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
generation=jnp.float32(0),
|
||||
)
|
||||
|
||||
# initialize species state
|
||||
state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])
|
||||
|
||||
return state.update(randkey=randkey)
|
||||
|
||||
def ask(self, state):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def tell(self, state, fitness):
|
||||
state = state.update(generation=state.generation + 1)
|
||||
|
||||
# tell fitness to species controller
|
||||
state, winner, loser, elite_mask = self.species_controller.update_species(
|
||||
state,
|
||||
fitness,
|
||||
)
|
||||
|
||||
# create next population
|
||||
state = self._create_next_generation(state, winner, loser, elite_mask)
|
||||
|
||||
# speciate the next population
|
||||
state = self.species_controller.speciate(state, self.genome.execute_distance)
|
||||
|
||||
return state
|
||||
|
||||
def transform(self, state, individual):
|
||||
nodes, conns = individual
|
||||
return self.genome.transform(state, nodes, conns)
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
return self.genome.forward(state, transformed, inputs)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.genome.num_inputs
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.genome.num_outputs
|
||||
|
||||
def _create_next_generation(self, state, winner, loser, elite_mask):
|
||||
|
||||
# find next node key for mutation
|
||||
all_nodes_keys = state.pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(
|
||||
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
|
||||
)
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
return state.update(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
)
|
||||
|
||||
def show_details(self, state, fitness):
|
||||
member_count = jax.device_get(state.species.member_count)
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
|
||||
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
|
||||
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
|
||||
|
||||
max_node_cnt, min_node_cnt, mean_node_cnt = (
|
||||
max(nodes_cnt),
|
||||
min(nodes_cnt),
|
||||
np.mean(nodes_cnt),
|
||||
)
|
||||
|
||||
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
|
||||
max(conns_cnt),
|
||||
min(conns_cnt),
|
||||
np.mean(conns_cnt),
|
||||
)
|
||||
|
||||
print(
|
||||
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
|
||||
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
|
||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||
)
|
||||
537
src/tensorneat/algorithm/neat/species.py
Normal file
537
src/tensorneat/algorithm/neat/species.py
Normal file
@@ -0,0 +1,537 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
StatefulBaseClass,
|
||||
rank_elements,
|
||||
argmin_with_mask,
|
||||
fetch_first,
|
||||
)
|
||||
|
||||
|
||||
class SpeciesController(StatefulBaseClass):
|
||||
def __init__(
|
||||
self,
|
||||
pop_size,
|
||||
species_size,
|
||||
max_stagnation,
|
||||
species_elitism,
|
||||
spawn_number_change_rate,
|
||||
genome_elitism,
|
||||
survival_threshold,
|
||||
min_species_size,
|
||||
compatibility_threshold,
|
||||
species_fitness_func,
|
||||
):
|
||||
self.pop_size = pop_size
|
||||
self.species_size = species_size
|
||||
self.species_arange = np.arange(self.species_size)
|
||||
self.max_stagnation = max_stagnation
|
||||
self.species_elitism = species_elitism
|
||||
self.spawn_number_change_rate = spawn_number_change_rate
|
||||
self.genome_elitism = genome_elitism
|
||||
self.survival_threshold = survival_threshold
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.species_fitness_func = species_fitness_func
|
||||
|
||||
def setup(self, state, first_nodes, first_conns):
|
||||
# the unique index (primary key) for each species
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the best fitness of each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the last 1 that the species improved
|
||||
last_improved = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the number of members of each species
|
||||
member_count = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the species index of each individual
|
||||
idx2species = jnp.zeros(self.pop_size)
|
||||
|
||||
# nodes for each center genome of each species
|
||||
center_nodes = jnp.full(
|
||||
(self.species_size, *first_nodes.shape),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
# connections for each center genome of each species
|
||||
center_conns = jnp.full(
|
||||
(self.species_size, *first_conns.shape),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
species_keys = species_keys.at[0].set(0)
|
||||
best_fitness = best_fitness.at[0].set(-jnp.inf)
|
||||
last_improved = last_improved.at[0].set(0)
|
||||
member_count = member_count.at[0].set(self.pop_size)
|
||||
center_nodes = center_nodes.at[0].set(first_nodes)
|
||||
center_conns = center_conns.at[0].set(first_conns)
|
||||
|
||||
species_state = State(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=jnp.float32(1), # 0 is reserved for the first species
|
||||
)
|
||||
|
||||
return state.register(species=species_state)
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
species_state = state.species
|
||||
|
||||
# update the fitness of each species
|
||||
species_fitness = self._update_species_fitness(species_state, fitness)
|
||||
|
||||
# stagnation species
|
||||
species_state, species_fitness = self._stagnation(
|
||||
species_state, species_fitness, state.generation
|
||||
)
|
||||
|
||||
# sort species_info by their fitness. (also push nan to the end)
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1] # fitness from high to low
|
||||
|
||||
species_state = species_state.update(
|
||||
species_keys=species_state.species_keys[sort_indices],
|
||||
best_fitness=species_state.best_fitness[sort_indices],
|
||||
last_improved=species_state.last_improved[sort_indices],
|
||||
member_count=species_state.member_count[sort_indices],
|
||||
center_nodes=species_state.center_nodes[sort_indices],
|
||||
center_conns=species_state.center_conns[sort_indices],
|
||||
)
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
spawn_number = self._cal_spawn_numbers(species_state)
|
||||
|
||||
k1, k2 = jax.random.split(state.randkey)
|
||||
# crossover info
|
||||
winner, loser, elite_mask = self._create_crossover_pair(
|
||||
species_state, k1, spawn_number, fitness
|
||||
)
|
||||
|
||||
return (
|
||||
state.update(randkey=k2, species=species_state),
|
||||
winner,
|
||||
loser,
|
||||
elite_mask,
|
||||
)
|
||||
|
||||
def _update_species_fitness(self, species_state, fitness):
|
||||
"""
|
||||
obtain the fitness of the species by the fitness of each individual.
|
||||
use max criterion.
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
s_fitness = jnp.where(
|
||||
species_state.idx2species == species_state.species_keys[idx],
|
||||
fitness,
|
||||
-jnp.inf,
|
||||
)
|
||||
val = self.species_fitness_func(s_fitness)
|
||||
return val
|
||||
|
||||
return vmap(aux_func)(self.species_arange)
|
||||
|
||||
def _stagnation(self, species_state, species_fitness, generation):
|
||||
"""
|
||||
stagnation species.
|
||||
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
|
||||
elitism species never stagnation
|
||||
"""
|
||||
|
||||
def check_stagnation(idx):
|
||||
# determine whether the species stagnation
|
||||
|
||||
# not better than the best fitness of the species
|
||||
# for a long time
|
||||
st = (species_fitness[idx] <= species_state.best_fitness[idx]) & (
|
||||
generation - species_state.last_improved[idx] > self.max_stagnation
|
||||
)
|
||||
|
||||
# update last_improved and best_fitness
|
||||
# whether better than the best fitness of the species
|
||||
li, bf = jax.lax.cond(
|
||||
species_fitness[idx] > species_state.best_fitness[idx],
|
||||
lambda: (generation, species_fitness[idx]), # update
|
||||
lambda: (
|
||||
species_state.last_improved[idx],
|
||||
species_state.best_fitness[idx],
|
||||
), # not update
|
||||
)
|
||||
|
||||
return st, bf, li
|
||||
|
||||
spe_st, best_fitness, last_improved = vmap(check_stagnation)(
|
||||
self.species_arange
|
||||
)
|
||||
|
||||
# update species state
|
||||
species_state = species_state.update(
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
)
|
||||
|
||||
# elite species will not be stagnation
|
||||
species_rank = rank_elements(species_fitness)
|
||||
spe_st = jnp.where(
|
||||
species_rank < self.species_elitism, False, spe_st
|
||||
) # elitism never stagnation
|
||||
|
||||
# set stagnation species to nan
|
||||
def update_func(idx):
|
||||
return jax.lax.cond(
|
||||
spe_st[idx],
|
||||
lambda: (
|
||||
jnp.nan, # species_key
|
||||
jnp.nan, # best_fitness
|
||||
jnp.nan, # last_improved
|
||||
jnp.nan, # member_count
|
||||
jnp.full_like(species_state.center_nodes[idx], jnp.nan),
|
||||
jnp.full_like(species_state.center_conns[idx], jnp.nan),
|
||||
-jnp.inf, # species_fitness
|
||||
), # stagnation species
|
||||
lambda: (
|
||||
species_state.species_keys[idx],
|
||||
species_state.best_fitness[idx],
|
||||
species_state.last_improved[idx],
|
||||
species_state.member_count[idx],
|
||||
species_state.center_nodes[idx],
|
||||
species_state.center_conns[idx],
|
||||
species_fitness[idx],
|
||||
), # not stagnation species
|
||||
)
|
||||
|
||||
(
|
||||
species_keys,
|
||||
best_fitness,
|
||||
last_improved,
|
||||
member_count,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
species_fitness,
|
||||
) = vmap(update_func)(self.species_arange)
|
||||
|
||||
return (
|
||||
species_state.update(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
),
|
||||
species_fitness,
|
||||
)
|
||||
|
||||
def _cal_spawn_numbers(self, species_state):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
the species with higher fitness will have more members
|
||||
Linear ranking selection
|
||||
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
|
||||
"""
|
||||
|
||||
species_keys = species_state.species_keys
|
||||
|
||||
is_species_valid = ~jnp.isnan(species_keys)
|
||||
valid_species_num = jnp.sum(is_species_valid)
|
||||
denominator = (
|
||||
(valid_species_num + 1) * valid_species_num / 2
|
||||
) # obtain 3 + 2 + 1 = 6
|
||||
|
||||
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(
|
||||
is_species_valid, spawn_number_rate, 0
|
||||
) # set invalid species to 0
|
||||
|
||||
target_spawn_number = jnp.floor(
|
||||
spawn_number_rate * self.pop_size
|
||||
) # calculate member
|
||||
|
||||
# Avoid too much variation of numbers for a species
|
||||
previous_size = species_state.member_count
|
||||
spawn_number = (
|
||||
previous_size
|
||||
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
|
||||
)
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = self.pop_size - jnp.sum(spawn_number)
|
||||
|
||||
# add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(error)
|
||||
|
||||
return spawn_number
|
||||
|
||||
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
|
||||
s_idx = self.species_arange
|
||||
p_idx = jnp.arange(self.pop_size)
|
||||
|
||||
def aux_func(key, idx):
|
||||
# choose parents from the in the same species
|
||||
# key -> randkey, idx -> the idx of current species
|
||||
|
||||
members = species_state.idx2species == species_state.species_keys[idx]
|
||||
members_num = jnp.sum(members)
|
||||
|
||||
members_fitness = jnp.where(members, fitness, -jnp.inf)
|
||||
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
|
||||
|
||||
survive_size = jnp.floor(self.survival_threshold * members_num).astype(
|
||||
jnp.int32
|
||||
)
|
||||
|
||||
select_pro = (p_idx < survive_size) / survive_size
|
||||
fa, ma = jax.random.choice(
|
||||
key,
|
||||
sorted_member_indices,
|
||||
shape=(2, self.pop_size),
|
||||
replace=True,
|
||||
p=select_pro,
|
||||
)
|
||||
|
||||
# elite
|
||||
fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa)
|
||||
ma = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, ma)
|
||||
elite = jnp.where(p_idx < self.genome_elitism, True, False)
|
||||
return fa, ma, elite
|
||||
|
||||
# choose parents to crossover in each species
|
||||
# fas, mas, elites: (self.species_size, self.pop_size)
|
||||
# fas -> father indices, mas -> mother indices, elites -> whether elite or not
|
||||
fas, mas, elites = vmap(aux_func)(
|
||||
jax.random.split(randkey, self.species_size), s_idx
|
||||
)
|
||||
|
||||
# merge choosen parents from each species into one array
|
||||
# winner, loser, elite_mask: (self.pop_size)
|
||||
# winner -> winner indices, loser -> loser indices, elite_mask -> whether elite or not
|
||||
spawn_number_cum = jnp.cumsum(spawn_number)
|
||||
|
||||
def aux_func(idx):
|
||||
loc = jnp.argmax(idx < spawn_number_cum)
|
||||
|
||||
# elite genomes are at the beginning of the species
|
||||
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
|
||||
return (
|
||||
fas[loc, idx_in_species],
|
||||
mas[loc, idx_in_species],
|
||||
elites[loc, idx_in_species],
|
||||
)
|
||||
|
||||
part1, part2, elite_mask = vmap(aux_func)(p_idx)
|
||||
|
||||
is_part1_win = fitness[part1] >= fitness[part2]
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return winner, loser, elite_mask
|
||||
|
||||
def speciate(self, state, genome_distance_func: Callable):
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(
|
||||
genome_distance_func, in_axes=(None, None, None, 0, 0)
|
||||
) # one to population
|
||||
|
||||
# idx to specie key
|
||||
idx2species = jnp.full(
|
||||
(self.pop_size,), jnp.nan
|
||||
) # NaN means not assigned to any species
|
||||
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
|
||||
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
# i, idx2species, center_nodes, center_conns, o2c_distances
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
return (i < self.species_size) & (
|
||||
~jnp.isnan(state.species.species_keys[i])
|
||||
) # current species is existing
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
distances = o2p_distance_func(
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
|
||||
i2s = i2s.at[closest_idx].set(state.species.species_keys[i])
|
||||
cns = cns.at[i].set(state.pop_nodes[closest_idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
|
||||
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
|
||||
return i + 1, i2s, cns, ccs, o2c
|
||||
|
||||
_, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(
|
||||
0,
|
||||
idx2species,
|
||||
state.species.center_nodes,
|
||||
state.species.center_conns,
|
||||
o2c_distances,
|
||||
),
|
||||
)
|
||||
|
||||
state = state.update(
|
||||
species=state.species.update(
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
),
|
||||
)
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
# i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
current_species_existed = ~jnp.isnan(sk[i])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < self.species_size
|
||||
return not_reach_species_upper_bounds & (
|
||||
current_species_existed | not_all_assigned
|
||||
)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
_, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(sk[i]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cns, ccs, sk, o2c, nsk),
|
||||
)
|
||||
|
||||
return i + 1, i2s, cns, ccs, sk, o2c, nsk
|
||||
|
||||
def create_new_species(carry):
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, member_count]
|
||||
sk = sk.at[i].set(nsk) # nsk -> next species key
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
# update center genomes
|
||||
cns = cns.at[i].set(state.pop_nodes[idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[idx])
|
||||
|
||||
# find the members for the new species
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
||||
|
||||
return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
||||
|
||||
# turn to next species
|
||||
return i + 1, i2s, cns, ccs, sk, o2c, nsk
|
||||
|
||||
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
close_enough_mask = o2p_distance < self.compatibility_threshold
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||
|
||||
mask = close_enough_mask & catchable_mask
|
||||
|
||||
# update species info
|
||||
i2s = jnp.where(mask, sk[i], i2s)
|
||||
|
||||
# update distance between centers
|
||||
o2c = jnp.where(mask, o2p_distance, o2c)
|
||||
|
||||
return i2s, o2c
|
||||
|
||||
# update idx2species
|
||||
(
|
||||
_,
|
||||
idx2species,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
species_keys,
|
||||
_,
|
||||
next_species_key,
|
||||
) = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(
|
||||
0,
|
||||
state.species.idx2species,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
state.species.species_keys,
|
||||
o2c_distances,
|
||||
state.species.next_species_key,
|
||||
),
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition can only happen when the number of species is reached species upper bounds
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
|
||||
last_improved = jnp.where(
|
||||
new_created_mask, state.generation, state.species.last_improved
|
||||
)
|
||||
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
return jax.lax.cond(
|
||||
jnp.isnan(species_keys[idx]), # if the species is not existing
|
||||
lambda: jnp.nan, # nan
|
||||
lambda: jnp.sum(
|
||||
idx2species == species_keys[idx], dtype=jnp.float32
|
||||
), # count members
|
||||
)
|
||||
|
||||
member_count = vmap(count_members)(self.species_arange)
|
||||
|
||||
species_state = state.species.update(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=next_species_key,
|
||||
)
|
||||
|
||||
return state.update(
|
||||
species=species_state,
|
||||
)
|
||||
Reference in New Issue
Block a user