use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -3,58 +3,57 @@ from utils import State
|
||||
from .. import BaseAlgorithm
|
||||
from .species import *
|
||||
from .ga import *
|
||||
from .genome import *
|
||||
|
||||
|
||||
class NEAT(BaseAlgorithm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
species: BaseSpecies,
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
self,
|
||||
species: BaseSpecies,
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
):
|
||||
self.genome = species.genome
|
||||
self.genome: BaseGenome = species.genome
|
||||
self.species = species
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
|
||||
def setup(self, randkey):
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
return State(
|
||||
randkey=k1,
|
||||
generation=jnp.array(0.),
|
||||
next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32),
|
||||
# inputs nodes, output nodes, 1 hidden node
|
||||
species=self.species.setup(k2),
|
||||
def setup(self, state=State()):
|
||||
state = self.species.setup(state)
|
||||
state = self.mutation.setup(state)
|
||||
state = self.crossover.setup(state)
|
||||
state = state.register(
|
||||
generation=jnp.array(0.0),
|
||||
next_node_key=jnp.array(
|
||||
max(*self.genome.input_idx, *self.genome.output_idx) + 2,
|
||||
dtype=jnp.float32,
|
||||
),
|
||||
)
|
||||
return state
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.species.ask(state.species)
|
||||
return state, self.species.ask(state.species)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(
|
||||
generation=state.generation + 1,
|
||||
randkey=randkey
|
||||
state = state.update(generation=state.generation + 1, randkey=randkey)
|
||||
|
||||
state, winner, loser, elite_mask = self.species.update_species(
|
||||
state.species, fitness
|
||||
)
|
||||
state = self.create_next_generation(state, winner, loser, elite_mask)
|
||||
state = self.species.speciate(state.species)
|
||||
|
||||
species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation)
|
||||
state = state.update(species=species_state)
|
||||
|
||||
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
|
||||
|
||||
species_state = self.species.speciate(state.species, state.generation)
|
||||
state = state.update(species=species_state)
|
||||
return state
|
||||
|
||||
def transform(self, individual):
|
||||
def transform(self, state, individual):
|
||||
"""transform the genome into a neural network"""
|
||||
nodes, conns = individual
|
||||
return self.genome.transform(nodes, conns)
|
||||
return self.genome.transform(state, nodes, conns)
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
return self.genome.forward(inputs, transformed)
|
||||
def forward(self, state, inputs, transformed):
|
||||
return self.genome.forward(state, inputs, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
@@ -68,12 +67,12 @@ class NEAT(BaseAlgorithm):
|
||||
def pop_size(self):
|
||||
return self.species.pop_size
|
||||
|
||||
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
|
||||
def create_next_generation(self, state, winner, loser, elite_mask):
|
||||
# prepare random keys
|
||||
pop_size = self.species.pop_size
|
||||
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
@@ -81,12 +80,14 @@ class NEAT(BaseAlgorithm):
|
||||
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
|
||||
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
|
||||
n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))(
|
||||
crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc
|
||||
)
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
|
||||
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
|
||||
m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))(
|
||||
mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys
|
||||
)
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
@@ -94,20 +95,21 @@ class NEAT(BaseAlgorithm):
|
||||
|
||||
# update next node key
|
||||
all_nodes_keys = pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
|
||||
max_node_key = jnp.max(
|
||||
jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)
|
||||
)
|
||||
next_node_key = max_node_key + 1
|
||||
|
||||
return state.update(
|
||||
species=state.species.update(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
),
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
next_node_key=next_node_key,
|
||||
)
|
||||
|
||||
def member_count(self, state: State):
|
||||
return state.species.member_count
|
||||
return state, state.species.member_count
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
return state.generation
|
||||
return state, state.generation
|
||||
|
||||
Reference in New Issue
Block a user