finish all refactoring

This commit is contained in:
wls2002
2024-02-21 15:41:08 +08:00
parent aac41a089d
commit 6970e6a6d5
44 changed files with 856 additions and 825 deletions

View File

@@ -1,20 +1,19 @@
import jax, jax.numpy as jnp
from utils import State
from .. import BaseAlgorithm
from .genome import *
from .species import *
from .ga import *
class NEAT(BaseAlgorithm):
def __init__(
self,
genome: BaseGenome,
species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
):
self.genome = genome
self.genome = species.genome
self.species = species
self.mutation = mutation
self.crossover = crossover
@@ -23,14 +22,14 @@ class NEAT(BaseAlgorithm):
k1, k2 = jax.random.split(randkey, 2)
return State(
randkey=k1,
generation=0,
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2,
generation=jnp.array(0.),
next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32),
# inputs nodes, output nodes, 1 hidden node
species=self.species.setup(k2),
)
def ask(self, state: State):
return self.species.ask(state)
return self.species.ask(state.species)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
@@ -40,25 +39,39 @@ class NEAT(BaseAlgorithm):
randkey=randkey
)
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation)
species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation)
state = state.update(species=species_state)
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
state = self.species.speciate(state, state.generation)
species_state = self.species.speciate(state.species, state.generation)
state = state.update(species=species_state)
return state
def transform(self, state: State):
def transform(self, individual):
"""transform the genome into a neural network"""
raise NotImplementedError
nodes, conns = individual
return self.genome.transform(nodes, conns)
def forward(self, inputs, transformed):
raise NotImplementedError
return self.genome.forward(inputs, transformed)
@property
def num_inputs(self):
return self.genome.num_inputs
@property
def num_outputs(self):
return self.genome.num_outputs
@property
def pop_size(self):
return self.species.pop_size
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
# prepare random keys
pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
@@ -69,11 +82,11 @@ class NEAT(BaseAlgorithm):
# batch crossover
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
# batch mutation
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
@@ -92,3 +105,9 @@ class NEAT(BaseAlgorithm):
next_node_key=next_node_key,
)
def member_count(self, state: State):
return state.species.member_count
def generation(self, state: State):
# to analysis the algorithm
return state.generation