Change the argument order of genome.forward. Original: (state, inputs, transformed). New: (state, transformed, inputs).
118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
import jax, jax.numpy as jnp
|
|
from utils import State
|
|
from .. import BaseAlgorithm
|
|
from .species import *
|
|
from .ga import *
|
|
from .genome import *
|
|
|
|
|
|
class NEAT(BaseAlgorithm):
    """NEAT algorithm built on top of a species object.

    Speciation and parent selection are delegated to ``self.species``; genome
    level operations (transform, restore, forward, crossover, mutation) are
    delegated to ``self.genome``, which is taken from ``species.genome``.
    """

    def __init__(
        self,
        species: BaseSpecies,
    ):
        self.species = species
        self.genome = species.genome

    def setup(self, state=None):
        """Initialize the algorithm state.

        Args:
            state: Existing ``State`` to extend. When omitted, a fresh
                ``State()`` is created for this call (rather than one default
                instance shared by every call, as a ``State()`` default would be).

        Returns:
            The state produced by ``species.setup`` with ``generation`` (0.0)
            and ``next_node_key`` registered.
        """
        if state is None:
            state = State()
        state = self.species.setup(state)

        # Node keys handed out to new nodes start above every input/output
        # index. NOTE(review): the +2 offset is kept as-is; confirm why it is
        # +2 rather than +1.
        next_node_key = max(*self.genome.input_idx, *self.genome.output_idx) + 2
        state = state.register(
            generation=jnp.array(0.0),
            next_node_key=jnp.array(next_node_key, dtype=jnp.float32),
        )
        return state

    def ask(self, state: State):
        """Return whatever ``species.ask(state)`` returns (presumably the current population)."""
        return self.species.ask(state)

    def tell(self, state: State, fitness):
        """Advance the algorithm by one generation.

        Increments ``generation`` and advances ``randkey``, asks the species for
        winners/losers/elite mask, builds the next population from them, and
        re-speciates.

        Args:
            state: Current algorithm state.
            fitness: Population fitness, passed unchanged to
                ``species.update_species``.

        Returns:
            The updated state.
        """
        # Only the third sub-key is kept; splitting into 3 keeps the random
        # stream identical to the original implementation.
        _, _, randkey = jax.random.split(state.randkey, 3)
        state = state.update(generation=state.generation + 1, randkey=randkey)

        state, winner, loser, elite_mask = self.species.update_species(state, fitness)
        state = self.create_next_generation(state, winner, loser, elite_mask)
        state = self.species.speciate(state)
        return state

    def transform(self, state, individual):
        """Transform a genome ``(nodes, conns)`` into a network via ``genome.transform``."""
        nodes, conns = individual
        return self.genome.transform(state, nodes, conns)

    def restore(self, state, transformed):
        """Delegate to ``genome.restore`` on a transformed network."""
        return self.genome.restore(state, transformed)

    def forward(self, state, transformed, inputs):
        """Run a transformed network on ``inputs``.

        Argument order is ``(state, transformed, inputs)``, matching
        ``genome.forward``.
        """
        return self.genome.forward(state, transformed, inputs)

    def update_by_batch(self, state, batch_input, transformed):
        """Delegate to ``genome.update_by_batch``.

        NOTE(review): argument order here is ``(batch_input, transformed)``,
        the reverse of ``forward``'s ``(transformed, inputs)``; confirm whether
        this should follow the new convention (callers would need updating).
        """
        return self.genome.update_by_batch(state, batch_input, transformed)

    @property
    def num_inputs(self):
        """Number of network inputs, as reported by the genome."""
        return self.genome.num_inputs

    @property
    def num_outputs(self):
        """Number of network outputs, as reported by the genome."""
        return self.genome.num_outputs

    @property
    def pop_size(self):
        """Population size, as reported by the species."""
        return self.species.pop_size

    def create_next_generation(self, state, winner, loser, elite_mask):
        """Build the next population by batched crossover and mutation.

        Args:
            state: Current state; reads ``randkey``, ``next_node_key``,
                ``pop_nodes`` and ``pop_conns``.
            winner: Indices into the population selecting the first parent of
                each child.
            loser: Indices into the population selecting the second parent of
                each child.
            elite_mask: Per-child condition for ``jnp.where``; where true the
                crossover result is kept without mutation.

        Returns:
            ``state`` updated with new ``randkey``, ``pop_nodes``, ``pop_conns``
            and ``next_node_key``.
        """
        pop_size = self.species.pop_size

        # One distinct candidate key per child, passed to execute_mutation
        # (presumably used for any node the mutation adds).
        new_node_keys = jnp.arange(pop_size) + state.next_node_key

        # Prepare per-child random keys.
        k1, k2, randkey = jax.random.split(state.randkey, 3)
        crossover_randkeys = jax.random.split(k1, pop_size)
        mutate_randkeys = jax.random.split(k2, pop_size)

        wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
        lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]

        # Batched crossover: state is shared, everything else maps over children.
        n_nodes, n_conns = jax.vmap(
            self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
        )(state, crossover_randkeys, wpn, wpc, lpn, lpc)

        # Batched mutation of the crossover results.
        m_n_nodes, m_n_conns = jax.vmap(
            self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
        )(state, mutate_randkeys, n_nodes, n_conns, new_node_keys)

        # Elites keep their unmutated crossover result.
        pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
        pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)

        # Next node key = 1 + largest key (column 0 of pop_nodes) in the new
        # population; NaN entries (presumably empty node slots) are ignored.
        all_nodes_keys = pop_nodes[:, :, 0]
        max_node_key = jnp.max(
            jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)
        )
        next_node_key = max_node_key + 1

        return state.update(
            randkey=randkey,
            pop_nodes=pop_nodes,
            pop_conns=pop_conns,
            next_node_key=next_node_key,
        )

    def member_count(self, state: State):
        """Return ``state.member_count``."""
        return state.member_count

    def generation(self, state: State):
        """Return ``state.generation`` (for analysing the algorithm's progress)."""
        return state.generation
|