# File: tensorneat-mend/tensorneat/algorithm/neat/neat.py
# Listing metadata: 166 lines, 5.3 KiB, Python

import jax
from jax import vmap, numpy as jnp
import numpy as np
from .species import SpeciesController
from .. import BaseAlgorithm
from tensorneat.common import State
from tensorneat.genome import BaseGenome
class NEAT(BaseAlgorithm):
    """NEAT algorithm built from a genome definition and a species controller.

    The population lives in ``state`` as two batched arrays, ``pop_nodes`` and
    ``pop_conns``, whose leading axis indexes individuals.  ``ask`` exposes the
    current population and ``tell`` consumes its fitness to produce the next
    one (species update -> crossover/mutation -> re-speciation).
    """

    def __init__(
        self,
        genome: BaseGenome,
        pop_size: int,
        species_size: int = 10,
        max_stagnation: int = 15,
        species_elitism: int = 2,
        spawn_number_change_rate: float = 0.5,
        genome_elitism: int = 2,
        survival_threshold: float = 0.2,
        min_species_size: int = 1,
        compatibility_threshold: float = 3.0,
        species_fitness_func: callable = jnp.max,
    ):
        """Store the genome and build the species controller.

        Args:
            genome: Genome definition used for initialization, crossover,
                mutation, distance, transform and forward.
            pop_size: Number of individuals in the population.
            species_size, max_stagnation, species_elitism,
            spawn_number_change_rate, genome_elitism, survival_threshold,
            min_species_size, compatibility_threshold, species_fitness_func:
                Forwarded unchanged (positionally, in this order, after
                ``pop_size``) to ``SpeciesController``; see that class for
                their exact semantics.
        """
        self.genome = genome
        self.pop_size = pop_size
        self.species_controller = SpeciesController(
            pop_size,
            species_size,
            max_stagnation,
            species_elitism,
            spawn_number_change_rate,
            genome_elitism,
            survival_threshold,
            min_species_size,
            compatibility_threshold,
            species_fitness_func,
        )

    # NOTE(review): the default ``State()`` is evaluated once, when the method
    # is defined. That is only safe if State objects are never mutated in
    # place; ``register``/``update`` are used here as if they return new
    # objects -- confirm against the State implementation.
    def setup(self, state=State()):
        """Create the initial population and species bookkeeping.

        Args:
            state: State to extend. ``state.randkey`` must be available after
                ``genome.setup`` has run, because it is split here.

        Returns:
            The state with ``pop_nodes``, ``pop_conns`` and ``generation``
            registered, the species controller's entries set up, and
            ``randkey`` replaced by a fresh key.
        """
        # setup state
        state = self.genome.setup(state)

        # One sub-key drives initialization; the other becomes the new
        # state.randkey so later steps do not reuse consumed randomness.
        k1, randkey = jax.random.split(state.randkey, 2)

        # initialize the population: one random key per individual
        initialize_keys = jax.random.split(k1, self.pop_size)
        pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
            state, initialize_keys
        )

        state = state.register(
            pop_nodes=pop_nodes,
            pop_conns=pop_conns,
            generation=jnp.float32(0),
        )

        # initialize species state, seeded with the first individual's genome
        state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])

        return state.update(randkey=randkey)

    def ask(self, state):
        """Return the current population as ``(pop_nodes, pop_conns)``."""
        return state.pop_nodes, state.pop_conns

    def tell(self, state, fitness):
        """Advance one generation using the population's fitness.

        Args:
            state: Current algorithm state.
            fitness: Fitness values passed to the species controller
                (presumably one per individual -- confirm against
                ``SpeciesController.update_species``).

        Returns:
            The state holding the next, already speciated, population.
        """
        state = state.update(generation=state.generation + 1)

        # tell fitness to species controller
        state, winner, loser, elite_mask = self.species_controller.update_species(
            state,
            fitness,
        )

        # create next population
        state = self._create_next_generation(state, winner, loser, elite_mask)

        # speciate the next population
        state = self.species_controller.speciate(state, self.genome.execute_distance)

        return state

    def transform(self, state, individual):
        """Unpack ``individual`` as ``(nodes, conns)`` and delegate to the genome."""
        nodes, conns = individual
        return self.genome.transform(state, nodes, conns)

    def forward(self, state, transformed, inputs):
        """Delegate the forward pass to the genome."""
        return self.genome.forward(state, transformed, inputs)

    @property
    def num_inputs(self):
        """Number of inputs, as reported by the genome."""
        return self.genome.num_inputs

    @property
    def num_outputs(self):
        """Number of outputs, as reported by the genome."""
        return self.genome.num_outputs

    def _create_next_generation(self, state, winner, loser, elite_mask):
        """Build the next population by batched crossover and mutation.

        Args:
            state: Current algorithm state (reads ``pop_nodes``, ``pop_conns``
                and ``randkey``).
            winner: Indices used to gather the first parent of each child.
            loser: Indices used to gather the second parent of each child.
            elite_mask: Boolean mask over children; where true, the crossover
                result is kept without mutation.

        Returns:
            The state with ``pop_nodes``, ``pop_conns`` and ``randkey`` updated.
        """
        # find next node key for mutation: column 0 of every node row is read
        # as the node key, and NaN entries are excluded from the maximum
        # (presumably NaN marks unused slots -- confirm with the genome layout).
        all_nodes_keys = state.pop_nodes[:, :, 0]
        max_node_key = jnp.max(
            all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
        )
        next_node_key = max_node_key + 1
        # One distinct candidate key per individual, so nodes added by
        # different individuals in this generation do not share a key.
        new_node_keys = jnp.arange(self.pop_size) + next_node_key

        # prepare random keys
        k1, k2, randkey = jax.random.split(state.randkey, 3)
        crossover_randkeys = jax.random.split(k1, self.pop_size)
        mutate_randkeys = jax.random.split(k2, self.pop_size)

        # gather both parents of every child
        wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
        lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]

        # batch crossover
        n_nodes, n_conns = vmap(
            self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
        )(
            state, crossover_randkeys, wpn, wpc, lpn, lpc
        )  # new_nodes, new_conns

        # batch mutation
        m_n_nodes, m_n_conns = vmap(
            self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
        )(
            state, mutate_randkeys, n_nodes, n_conns, new_node_keys
        )  # mutated_new_nodes, mutated_new_conns

        # elitism: elites keep the un-mutated crossover result. The mask is
        # broadcast over the two trailing genome axes.
        pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
        pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)

        return state.update(
            randkey=randkey,
            pop_nodes=pop_nodes,
            pop_conns=pop_conns,
        )

    def show_details(self, state, fitness):
        """Print node/connection count statistics and species sizes.

        Args:
            state: Current algorithm state (reads ``species.member_count``,
                ``pop_nodes`` and ``pop_conns``).
            fitness: Not used by this method; kept in the signature so
                existing callers keep working.
        """
        member_count = jax.device_get(state.species.member_count)
        # Only species with at least one member are reported.
        species_sizes = [int(i) for i in member_count if i > 0]

        # Pull the population to host memory once; everything below is NumPy.
        pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
        # A row counts as present when its first column is not NaN.
        nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1)  # (P,)
        conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1)  # (P,)

        max_node_cnt, min_node_cnt, mean_node_cnt = (
            np.max(nodes_cnt),
            np.min(nodes_cnt),
            np.mean(nodes_cnt),
        )

        max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
            np.max(conns_cnt),
            np.min(conns_cnt),
            np.mean(conns_cnt),
        )

        print(
            f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
            f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
            f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
        )