Files
tensorneat-mend/algorithm/neat/neat.py
2023-08-02 15:02:08 +08:00

133 lines
5.1 KiB
Python

from typing import Optional, Type

import jax
from jax import numpy as jnp
import numpy as np

from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import create_next_generation
from .species import SpeciesInfo, update_species, speciate
class NEAT(Algorithm):
    """NEAT algorithm whose node/connection encoding is delegated to a gene type.

    Genomes are stored as fixed-size, NaN-padded arrays (sizes ``N`` and ``C`` set in
    ``setup``) so the whole population has a static shape.
    """

    def __init__(self, config: Config, gene_type: Type[Gene]):
        """
        :param config: full configuration; ``config.basic``, ``config.neat`` and
            ``config.gene`` are read by this class.
        :param gene_type: gene class, instantiated here with ``config.gene``.
        """
        self.config = config
        self.gene = gene_type(config.gene)
        # Not used inside this class; presumably filled in by external code — TODO confirm.
        self.forward_func = None
        self.tell_func = None

    def setup(self, randkey, state: Optional[State] = None):
        """Initialize the state of the algorithm.

        :param randkey: JAX PRNG key stored in the state; ``tell_algorithm`` splits it.
        :param state: an existing state to extend. A fresh ``State()`` is created when
            omitted, so separate calls never share a default instance.
        :return: the populated state, passed through ``jax.device_put``.
        """
        if state is None:
            state = State()

        neat_cfg = self.config.neat
        # Node keys: inputs are [0, inputs), outputs are [inputs, inputs + outputs).
        input_idx = np.arange(neat_cfg.inputs)
        output_idx = np.arange(neat_cfg.inputs, neat_cfg.inputs + neat_cfg.outputs)

        state = state.update(
            P=self.config.basic.pop_size,
            N=neat_cfg.maximum_nodes,
            C=neat_cfg.maximum_conns,
            S=neat_cfg.maximum_species,
            NL=1 + len(self.gene.node_attrs),  # node row = (key) + attributes
            CL=3 + len(self.gene.conn_attrs),  # conn row = (in key, out key, enabled) + attributes
            max_stagnation=neat_cfg.max_stagnation,
            species_elitism=neat_cfg.species_elitism,
            spawn_number_change_rate=neat_cfg.spawn_number_change_rate,
            genome_elitism=neat_cfg.genome_elitism,
            survival_threshold=neat_cfg.survival_threshold,
            compatibility_threshold=neat_cfg.compatibility_threshold,
            compatibility_disjoint=neat_cfg.compatibility_disjoint,
            compatibility_weight=neat_cfg.compatibility_weight,
            input_idx=input_idx,
            output_idx=output_idx,
        )

        state = self.gene.setup(state)
        pop_genomes = self._initialize_genomes(state)

        species_info = SpeciesInfo.initialize(state)
        # One species entry per individual; float32, the same dtype as next_species_key below.
        idx2species = jnp.zeros(state.P, dtype=jnp.float32)

        # S NaN-padded center genomes; slot 0 is seeded with the first individual.
        center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
        center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
        center_genomes = Genome(center_nodes, center_conns)
        center_genomes = center_genomes.set(0, pop_genomes[0])

        generation = 0
        # _initialize_genomes gives its single hidden node the key max(io keys) + 1,
        # so the next unused node key is max(io keys) + 2.
        next_node_key = max([*state.input_idx, *state.output_idx]) + 2
        next_species_key = 1

        state = state.update(
            randkey=randkey,
            pop_genomes=pop_genomes,
            species_info=species_info,
            idx2species=idx2species,
            center_genomes=center_genomes,
            # avoid jax auto cast from int to float. that would cause re-compilation.
            generation=jnp.asarray(generation, dtype=jnp.int32),
            next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
            next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
        )

        return jax.device_put(state)

    def ask_algorithm(self, state: State):
        """Return the current population genomes stored in ``state``."""
        return state.pop_genomes

    def tell_algorithm(self, state: State, fitness):
        """Advance one generation from the population's fitness.

        Increments ``generation`` and replaces ``randkey``. It then runs, in order,
        ``update_species``, ``create_next_generation`` and ``speciate``.

        :param fitness: fitness values for the current population (passed to ``update_species``).
        :return: the updated state.
        """
        k1, k2, randkey = jax.random.split(state.randkey, 3)

        state = state.update(
            generation=state.generation + 1,
            randkey=randkey
        )

        state, winner, loser, elite_mask = update_species(state, k1, fitness)

        state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)

        state = speciate(self.gene, state)

        return state

    def forward_transform(self, state: State, genome: Genome):
        """Delegate to ``self.gene.forward_transform``."""
        return self.gene.forward_transform(state, genome)

    def forward(self, state: State, inputs, genome: Genome):
        """Delegate to ``self.gene.forward``."""
        return self.gene.forward(state, inputs, genome)

    def _initialize_genomes(self, state):
        """Build the initial population as ``P`` copies of one minimal genome.

        The genome contains every input and output node plus one hidden node. Each input
        connects to the hidden node, and the hidden node connects to each output. All
        other node/connection rows stay NaN (padding).

        :return: ``Genome`` with nodes of shape (P, N, NL) and conns of shape (P, C, CL).
        """
        o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32)  # original nodes
        o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32)  # original connections

        input_idx = state.input_idx
        output_idx = state.output_idx
        # The hidden node takes the first key after all input/output keys.
        new_node_key = max([*input_idx, *output_idx]) + 1

        # Each node is written to the row equal to its key: column 0 = key, 1: = attributes.
        o_nodes[input_idx, 0] = input_idx
        o_nodes[output_idx, 0] = output_idx
        o_nodes[new_node_key, 0] = new_node_key
        o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
        o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)

        # input -> hidden connections, written to the rows indexed by input_idx
        input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
        o_conns[input_idx, 0:2] = input_conns  # in key, out key
        o_conns[input_idx, 2] = True  # enabled
        o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)

        # hidden -> output connections, written to the rows indexed by output_idx
        output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
        o_conns[output_idx, 0:2] = output_conns  # in key, out key
        o_conns[output_idx, 2] = True  # enabled
        o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)

        # repeat origin genome for P times to create population
        pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
        pop_conns = np.tile(o_conns, (state.P, 1, 1))

        return Genome(pop_nodes, pop_conns)