new architecture
This commit is contained in:
@@ -1,87 +1,38 @@
|
||||
from typing import Type
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State
|
||||
from .. import BaseAlgorithm
|
||||
from .genome import *
|
||||
from .species import *
|
||||
from .ga import *
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
class NEAT(BaseAlgorithm):
|
||||
|
||||
from config import Config
|
||||
from core import Algorithm, State, Gene, Genome
|
||||
from .ga import create_next_generation
|
||||
from .species import SpeciesInfo, update_species, speciate
|
||||
def __init__(
|
||||
self,
|
||||
genome: BaseGenome,
|
||||
species: BaseSpecies,
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
):
|
||||
self.genome = genome
|
||||
self.species = species
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
|
||||
|
||||
class NEAT(Algorithm):
|
||||
|
||||
def __init__(self, config: Config, gene_type: Type[Gene]):
|
||||
self.config = config
|
||||
self.gene = gene_type(config.gene)
|
||||
|
||||
self.forward_func = None
|
||||
self.tell_func = None
|
||||
|
||||
def setup(self, randkey, state: State = State()):
|
||||
"""initialize the state of the algorithm"""
|
||||
|
||||
input_idx = np.arange(self.config.neat.inputs)
|
||||
output_idx = np.arange(self.config.neat.inputs,
|
||||
self.config.neat.inputs + self.config.neat.outputs)
|
||||
|
||||
state = state.update(
|
||||
P=self.config.basic.pop_size,
|
||||
N=self.config.neat.max_nodes,
|
||||
C=self.config.neat.max_conns,
|
||||
S=self.config.neat.max_species,
|
||||
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
|
||||
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
|
||||
max_stagnation=self.config.neat.max_stagnation,
|
||||
species_elitism=self.config.neat.species_elitism,
|
||||
spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
|
||||
genome_elitism=self.config.neat.genome_elitism,
|
||||
survival_threshold=self.config.neat.survival_threshold,
|
||||
compatibility_threshold=self.config.neat.compatibility_threshold,
|
||||
compatibility_disjoint=self.config.neat.compatibility_disjoint,
|
||||
compatibility_weight=self.config.neat.compatibility_weight,
|
||||
|
||||
input_idx=input_idx,
|
||||
output_idx=output_idx,
|
||||
def setup(self, randkey):
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
return State(
|
||||
randkey=k1,
|
||||
generation=0,
|
||||
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2,
|
||||
# inputs nodes, output nodes, 1 hidden node
|
||||
species=self.species.setup(k2),
|
||||
)
|
||||
|
||||
state = self.gene.setup(state)
|
||||
pop_genomes = self._initialize_genomes(state)
|
||||
|
||||
species_info = SpeciesInfo.initialize(state)
|
||||
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
|
||||
|
||||
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
|
||||
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
|
||||
center_genomes = Genome(center_nodes, center_conns)
|
||||
center_genomes = center_genomes.set(0, pop_genomes[0])
|
||||
|
||||
generation = 0
|
||||
next_node_key = max(*state.input_idx, *state.output_idx) + 2
|
||||
next_species_key = 1
|
||||
|
||||
state = state.update(
|
||||
randkey=randkey,
|
||||
pop_genomes=pop_genomes,
|
||||
species_info=species_info,
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
|
||||
# avoid jax auto cast from int to float. that would cause re-compilation.
|
||||
generation=jnp.asarray(generation, dtype=jnp.int32),
|
||||
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
|
||||
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
|
||||
)
|
||||
|
||||
return jax.device_put(state)
|
||||
|
||||
def ask_algorithm(self, state: State):
|
||||
return state.pop_genomes
|
||||
|
||||
def tell_algorithm(self, state: State, fitness):
|
||||
state = self.gene.update(state)
|
||||
def ask(self, state: State):
|
||||
return self.species.ask(state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(
|
||||
@@ -89,46 +40,55 @@ class NEAT(Algorithm):
|
||||
randkey=randkey
|
||||
)
|
||||
|
||||
state, winner, loser, elite_mask = update_species(state, k1, fitness)
|
||||
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation)
|
||||
|
||||
state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)
|
||||
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
|
||||
|
||||
state = speciate(self.gene, state)
|
||||
state = self.species.speciate(state, state.generation)
|
||||
|
||||
return state
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
return self.gene.forward_transform(state, genome)
|
||||
def transform(self, state: State):
|
||||
"""transform the genome into a neural network"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state: State, inputs, genome: Genome):
|
||||
return self.gene.forward(state, inputs, genome)
|
||||
def forward(self, inputs, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def _initialize_genomes(self, state):
|
||||
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
||||
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
|
||||
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
|
||||
# prepare random keys
|
||||
pop_size = self.species.pop_size
|
||||
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key
|
||||
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
o_nodes[input_idx, 0] = input_idx
|
||||
o_nodes[output_idx, 0] = output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
|
||||
o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)
|
||||
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
|
||||
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)
|
||||
# batch crossover
|
||||
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
|
||||
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
|
||||
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (state.P, 1, 1))
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
# update next node key
|
||||
all_nodes_keys = pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
|
||||
next_node_key = max_node_key + 1
|
||||
|
||||
return state.update(
|
||||
species=state.species.update(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
),
|
||||
next_node_key=next_node_key,
|
||||
)
|
||||
|
||||
return Genome(pop_nodes, pop_conns)
|
||||
|
||||
Reference in New Issue
Block a user