complete fully stateful!
use black to format all files!
This commit is contained in:
@@ -10,18 +10,12 @@ class NEAT(BaseAlgorithm):
|
||||
def __init__(
|
||||
self,
|
||||
species: BaseSpecies,
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
):
|
||||
self.genome: BaseGenome = species.genome
|
||||
self.species = species
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
self.genome = species.genome
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.species.setup(state)
|
||||
state = self.mutation.setup(state)
|
||||
state = self.crossover.setup(state)
|
||||
state = state.register(
|
||||
generation=jnp.array(0.0),
|
||||
next_node_key=jnp.array(
|
||||
@@ -32,18 +26,16 @@ class NEAT(BaseAlgorithm):
|
||||
return state
|
||||
|
||||
def ask(self, state: State):
|
||||
return state, self.species.ask(state.species)
|
||||
return self.species.ask(state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(generation=state.generation + 1, randkey=randkey)
|
||||
|
||||
state, winner, loser, elite_mask = self.species.update_species(
|
||||
state.species, fitness
|
||||
)
|
||||
state, winner, loser, elite_mask = self.species.update_species(state, fitness)
|
||||
state = self.create_next_generation(state, winner, loser, elite_mask)
|
||||
state = self.species.speciate(state.species)
|
||||
state = self.species.speciate(state)
|
||||
|
||||
return state
|
||||
|
||||
@@ -73,21 +65,25 @@ class NEAT(BaseAlgorithm):
|
||||
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
crossover_randkeys = jax.random.split(k1, pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, pop_size)
|
||||
|
||||
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
|
||||
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))(
|
||||
crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc
|
||||
)
|
||||
n_nodes, n_conns = jax.vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))(
|
||||
mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys
|
||||
)
|
||||
m_n_nodes, m_n_conns = jax.vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
@@ -108,8 +104,8 @@ class NEAT(BaseAlgorithm):
|
||||
)
|
||||
|
||||
def member_count(self, state: State):
|
||||
return state, state.species.member_count
|
||||
return state.member_count
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
return state, state.generation
|
||||
return state.generation
|
||||
|
||||
Reference in New Issue
Block a user