131 lines
5.0 KiB
Python
131 lines
5.0 KiB
Python
import jax
|
|
from jax import numpy as jnp
|
|
import numpy as np
|
|
|
|
from config import Config
|
|
from core import Algorithm, State, Gene, Genome
|
|
from .ga import create_next_generation
|
|
from .species import SpeciesInfo, update_species, speciate
|
|
|
|
|
|
class NEAT(Algorithm):
    """NEAT algorithm: builds an initial population and evolves it.

    `setup` creates the initial state (population, species bookkeeping,
    counters); `ask_algorithm` exposes the current population;
    `tell_algorithm` consumes fitness values and advances one generation.
    Network evaluation is delegated to the `gene` object.
    """

    def __init__(self, config: Config, gene: Gene):
        """Store the configuration and the gene implementation.

        Args:
            config: read in this class as `config.basic.pop_size` and
                `config.neat.*`.
            gene: provides `node_attrs` / `conn_attrs`, `setup`,
                `new_node_attrs`, `new_conn_attrs`, `forward_transform`
                and `forward`.
        """
        self.config = config
        self.gene = gene

        # Not assigned anywhere else in this class; presumably filled in by
        # the caller/pipeline -- TODO confirm.
        self.forward_func = None
        self.tell_func = None

    # NOTE(review): the default `State()` is evaluated once, when the method is
    # defined. That is safe only if `State.update` returns a new object instead
    # of mutating in place (the code below always rebinds its result) -- confirm.
    def setup(self, randkey, state: State = State()):
        """Initialize the state of the algorithm.

        Args:
            randkey: random key stored in the state; `tell_algorithm` splits
                it with `jax.random.split`.
            state: state to extend with the algorithm's fields.

        Returns:
            The populated state, passed through `jax.device_put`.
        """
        # Inputs take indices [0, inputs); outputs take the next `outputs` indices.
        input_idx = np.arange(self.config.neat.inputs)
        output_idx = np.arange(self.config.neat.inputs,
                               self.config.neat.inputs + self.config.neat.outputs)

        state = state.update(
            P=self.config.basic.pop_size,
            N=self.config.neat.maximum_nodes,
            C=self.config.neat.maximum_conns,
            S=self.config.neat.maximum_species,
            NL=1 + len(self.gene.node_attrs),  # node length = (key) + attributes
            # conn length = (in key, out key, enabled) + attributes;
            # see the column layout written by `_initialize_genomes`.
            CL=3 + len(self.gene.conn_attrs),
            max_stagnation=self.config.neat.max_stagnation,
            species_elitism=self.config.neat.species_elitism,
            spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
            genome_elitism=self.config.neat.genome_elitism,
            survival_threshold=self.config.neat.survival_threshold,
            compatibility_threshold=self.config.neat.compatibility_threshold,
            compatibility_disjoint=self.config.neat.compatibility_disjoint,
            compatibility_weight=self.config.neat.compatibility_weight,

            input_idx=input_idx,
            output_idx=output_idx,
        )

        # Let the gene add its own fields before genomes are created, since
        # `_initialize_genomes` asks the gene for attributes using `state`.
        state = self.gene.setup(state)
        pop_genomes = self._initialize_genomes(state)

        species_info = SpeciesInfo.initialize(state)
        # One entry per genome, all starting at 0.
        idx2species = jnp.zeros(state.P, dtype=jnp.float32)

        # One center genome slot per possible species, NaN-filled; slot 0 is
        # set to the first genome of the initial population.
        center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
        center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
        center_genomes = Genome(center_nodes, center_conns)
        center_genomes = center_genomes.set(0, pop_genomes[0])

        generation = 0
        # `_initialize_genomes` already uses key `max index + 1` for the extra
        # node it creates, so the next free node key is `max index + 2`.
        # The list form of `max` (same as in `_initialize_genomes`) also works
        # when there is only a single index in total, unlike `max(*a, *b)`.
        next_node_key = max([*state.input_idx, *state.output_idx]) + 2
        next_species_key = 1

        state = state.update(
            randkey=randkey,
            pop_genomes=pop_genomes,
            species_info=species_info,
            idx2species=idx2species,
            center_genomes=center_genomes,

            # Store counters as arrays with explicit dtypes to avoid jax auto
            # cast from int to float; that would cause re-compilation.
            generation=jnp.asarray(generation, dtype=jnp.int32),
            next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
            next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
        )

        return jax.device_put(state)

    def ask_algorithm(self, state: State):
        """Return the current population genomes stored in `state`."""
        return state.pop_genomes

    def tell_algorithm(self, state: State, fitness):
        """Advance the algorithm by one generation.

        Args:
            state: current algorithm state.
            fitness: fitness values, forwarded to `update_species`.

        Returns:
            The state after updating species, creating the next generation
            and re-running speciation.
        """
        # Two keys for this step's sub-routines; the third replaces the stored key.
        k1, k2, randkey = jax.random.split(state.randkey, 3)

        state = state.update(
            generation=state.generation + 1,
            randkey=randkey
        )

        state, winner, loser, elite_mask = update_species(state, k1, fitness)

        state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)

        state = speciate(self.gene, state)

        return state

    def forward_transform(self, state: State, genome: Genome):
        """Delegate to `gene.forward_transform` for the given genome."""
        return self.gene.forward_transform(state, genome)

    def forward(self, state: State, inputs, genome: Genome):
        """Delegate to `gene.forward` to evaluate `genome` on `inputs`."""
        return self.gene.forward(state, inputs, genome)

    def _initialize_genomes(self, state):
        """Build the initial population as `P` copies of one template genome.

        The template contains every input node, every output node and one
        extra node (key = largest input/output index + 1). Each input is
        connected to the extra node and the extra node is connected to each
        output. Unused rows stay NaN.

        Nodes are stored at the row equal to their key, and connections at
        the rows given by `input_idx` / `output_idx`, so `state.N` must exceed
        the extra node's key and `state.C` must cover all input and output
        indices; otherwise the indexed assignments raise `IndexError`.

        Returns:
            Genome whose node array has shape (P, N, NL) and whose connection
            array has shape (P, C, CL).
        """
        o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32)  # original nodes
        o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32)  # original connections

        input_idx = state.input_idx
        output_idx = state.output_idx
        # `setup` relies on this being max + 1 when it computes `next_node_key`.
        new_node_key = max([*input_idx, *output_idx]) + 1

        # Column 0 holds the node key; the remaining columns hold attributes.
        o_nodes[input_idx, 0] = input_idx
        o_nodes[output_idx, 0] = output_idx
        o_nodes[new_node_key, 0] = new_node_key
        o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
        o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)

        # Connections: every input -> extra node.
        input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
        o_conns[input_idx, 0:2] = input_conns  # in key, out key
        o_conns[input_idx, 2] = True  # enabled
        o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)

        # Connections: extra node -> every output.
        output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
        o_conns[output_idx, 0:2] = output_conns  # in key, out key
        o_conns[output_idx, 2] = True  # enabled
        o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)

        # repeat origin genome for P times to create population
        pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
        pop_conns = np.tile(o_conns, (state.P, 1, 1))

        return Genome(pop_nodes, pop_conns)
|