remove create_func....
This commit is contained in:
@@ -1,20 +1,18 @@
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, vmap
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from config import Config
|
||||
from core import Algorithm, State, Gene, Genome
|
||||
from .ga import crossover, create_mutate
|
||||
from .species import SpeciesInfo, update_species, create_speciate
|
||||
from .ga import create_next_generation
|
||||
from .species import SpeciesInfo, update_species, speciate
|
||||
|
||||
|
||||
class NEAT(Algorithm):
|
||||
|
||||
def __init__(self, config: Config, gene_type: Type[Gene]):
|
||||
def __init__(self, config: Config, gene: Gene):
|
||||
self.config = config
|
||||
self.gene_type = gene_type
|
||||
self.gene = gene
|
||||
|
||||
self.forward_func = None
|
||||
self.tell_func = None
|
||||
@@ -31,8 +29,8 @@ class NEAT(Algorithm):
|
||||
N=self.config.neat.maximum_nodes,
|
||||
C=self.config.neat.maximum_conns,
|
||||
S=self.config.neat.maximum_species,
|
||||
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
|
||||
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
|
||||
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
|
||||
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
|
||||
max_stagnation=self.config.neat.max_stagnation,
|
||||
species_elitism=self.config.neat.species_elitism,
|
||||
spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
|
||||
@@ -46,7 +44,7 @@ class NEAT(Algorithm):
|
||||
output_idx=output_idx,
|
||||
)
|
||||
|
||||
state = self.gene_type.setup(self.config.gene, state)
|
||||
state = self.gene.setup(state)
|
||||
pop_genomes = self._initialize_genomes(state)
|
||||
|
||||
species_info = SpeciesInfo.initialize(state)
|
||||
@@ -74,26 +72,32 @@ class NEAT(Algorithm):
|
||||
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
|
||||
)
|
||||
|
||||
self.forward_func = self.gene_type.create_forward(state, self.config.gene)
|
||||
self.tell_func = self._create_tell()
|
||||
|
||||
return jax.device_put(state)
|
||||
|
||||
def ask(self, state: State):
|
||||
"""require the population to be evaluated"""
|
||||
def ask_algorithm(self, state: State):
|
||||
return state.pop_genomes
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
"""update the state of the algorithm"""
|
||||
return self.tell_func(state, fitness)
|
||||
def tell_algorithm(self, state: State, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
def forward(self, inputs: Array, transformed: Array):
|
||||
"""the forward function of a single forward transformation"""
|
||||
return self.forward_func(inputs, transformed)
|
||||
state = state.update(
|
||||
generation=state.generation + 1,
|
||||
randkey=randkey
|
||||
)
|
||||
|
||||
state, winner, loser, elite_mask = update_species(state, k1, fitness)
|
||||
|
||||
state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)
|
||||
|
||||
state = speciate(self.gene, state)
|
||||
|
||||
return state
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
"""create the forward transformation of a genome"""
|
||||
return self.gene_type.forward_transform(state, genome)
|
||||
return self.gene.forward_transform(state, genome)
|
||||
|
||||
def forward(self, state: State, inputs, genome: Genome):
|
||||
return self.gene.forward(state, inputs, genome)
|
||||
|
||||
def _initialize_genomes(self, state):
|
||||
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
||||
@@ -106,80 +110,21 @@ class NEAT(Algorithm):
|
||||
o_nodes[input_idx, 0] = input_idx
|
||||
o_nodes[output_idx, 0] = output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene_type.new_node_attrs(state)
|
||||
o_nodes[new_node_key, 1:] = self.gene_type.new_node_attrs(state)
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
|
||||
o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = self.gene_type.new_conn_attrs(state)
|
||||
o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = self.gene_type.new_conn_attrs(state)
|
||||
o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (state.P, 1, 1))
|
||||
|
||||
return Genome(pop_nodes, pop_conns)
|
||||
|
||||
def _create_tell(self):
|
||||
mutate = create_mutate(self.config.neat, self.gene_type)
|
||||
|
||||
def create_next_generation(state, randkey, winner, loser, elite_mask):
|
||||
# prepare random keys
|
||||
pop_size = state.idx2species.shape[0]
|
||||
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner]
|
||||
lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser]
|
||||
n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc))
|
||||
|
||||
# batch mutation
|
||||
mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0))
|
||||
m_n_genomes = mutate_func(state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns)
|
||||
|
||||
# update next node key
|
||||
all_nodes_keys = pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
|
||||
next_node_key = max_node_key + 1
|
||||
|
||||
return state.update(
|
||||
pop_genomes=Genome(pop_nodes, pop_conns),
|
||||
next_node_key=next_node_key,
|
||||
)
|
||||
|
||||
speciate = create_speciate(self.gene_type)
|
||||
|
||||
def tell(state, fitness):
|
||||
"""
|
||||
Main update function in NEAT.
|
||||
"""
|
||||
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(
|
||||
generation=state.generation + 1,
|
||||
randkey=randkey
|
||||
)
|
||||
|
||||
state, winner, loser, elite_mask = update_species(state, k1, fitness)
|
||||
|
||||
state = create_next_generation(state, k2, winner, loser, elite_mask)
|
||||
|
||||
state = speciate(state)
|
||||
|
||||
return state
|
||||
|
||||
return tell
|
||||
|
||||
Reference in New Issue
Block a user