remove create_func....

This commit is contained in:
wls2002
2023-08-02 13:26:01 +08:00
parent 85318f98f3
commit 1499e062fe
34 changed files with 558 additions and 1022 deletions

View File

@@ -1,20 +1,18 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
from jax import numpy as jnp
import numpy as np
from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import crossover, create_mutate
from .species import SpeciesInfo, update_species, create_speciate
from .ga import create_next_generation
from .species import SpeciesInfo, update_species, speciate
class NEAT(Algorithm):
def __init__(self, config: Config, gene_type: Type[Gene]):
def __init__(self, config: Config, gene: Gene):
self.config = config
self.gene_type = gene_type
self.gene = gene
self.forward_func = None
self.tell_func = None
@@ -31,8 +29,8 @@ class NEAT(Algorithm):
N=self.config.neat.maximum_nodes,
C=self.config.neat.maximum_conns,
S=self.config.neat.maximum_species,
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
max_stagnation=self.config.neat.max_stagnation,
species_elitism=self.config.neat.species_elitism,
spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
@@ -46,7 +44,7 @@ class NEAT(Algorithm):
output_idx=output_idx,
)
state = self.gene_type.setup(self.config.gene, state)
state = self.gene.setup(state)
pop_genomes = self._initialize_genomes(state)
species_info = SpeciesInfo.initialize(state)
@@ -74,26 +72,32 @@ class NEAT(Algorithm):
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
)
self.forward_func = self.gene_type.create_forward(state, self.config.gene)
self.tell_func = self._create_tell()
return jax.device_put(state)
def ask(self, state: State):
"""require the population to be evaluated"""
def ask_algorithm(self, state: State):
return state.pop_genomes
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
return self.tell_func(state, fitness)
def tell_algorithm(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
def forward(self, inputs: Array, transformed: Array):
"""the forward function of a single forward transformation"""
return self.forward_func(inputs, transformed)
state = state.update(
generation=state.generation + 1,
randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)
state = speciate(self.gene, state)
return state
def forward_transform(self, state: State, genome: Genome):
"""create the forward transformation of a genome"""
return self.gene_type.forward_transform(state, genome)
return self.gene.forward_transform(state, genome)
def forward(self, state: State, inputs, genome: Genome):
return self.gene.forward(state, inputs, genome)
def _initialize_genomes(self, state):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
@@ -106,80 +110,21 @@ class NEAT(Algorithm):
o_nodes[input_idx, 0] = input_idx
o_nodes[output_idx, 0] = output_idx
o_nodes[new_node_key, 0] = new_node_key
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene_type.new_node_attrs(state)
o_nodes[new_node_key, 1:] = self.gene_type.new_node_attrs(state)
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = self.gene_type.new_conn_attrs(state)
o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = self.gene_type.new_conn_attrs(state)
o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
return Genome(pop_nodes, pop_conns)
def _create_tell(self):
mutate = create_mutate(self.config.neat, self.gene_type)
def create_next_generation(state, randkey, winner, loser, elite_mask):
# prepare random keys
pop_size = state.idx2species.shape[0]
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner]
lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser]
n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc))
# batch mutation
mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0))
m_n_genomes = mutate_func(state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
pop_genomes=Genome(pop_nodes, pop_conns),
next_node_key=next_node_key,
)
speciate = create_speciate(self.gene_type)
def tell(state, fitness):
"""
Main update function in NEAT.
"""
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(
generation=state.generation + 1,
randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state = create_next_generation(state, k2, winner, loser, elite_mask)
state = speciate(state)
return state
return tell