complete fully stateful!

use black to format all files!
This commit is contained in:
wls2002
2024-05-26 18:08:43 +08:00
parent cf69b916af
commit 18c3d44c79
41 changed files with 620 additions and 495 deletions

View File

@@ -10,18 +10,12 @@ class NEAT(BaseAlgorithm):
def __init__(
self,
species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
):
self.genome: BaseGenome = species.genome
self.species = species
self.mutation = mutation
self.crossover = crossover
self.genome = species.genome
def setup(self, state=State()):
state = self.species.setup(state)
state = self.mutation.setup(state)
state = self.crossover.setup(state)
state = state.register(
generation=jnp.array(0.0),
next_node_key=jnp.array(
@@ -32,18 +26,16 @@ class NEAT(BaseAlgorithm):
return state
def ask(self, state: State):
return state, self.species.ask(state.species)
return self.species.ask(state)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.species.update_species(
state.species, fitness
)
state, winner, loser, elite_mask = self.species.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.species.speciate(state.species)
state = self.species.speciate(state)
return state
@@ -73,21 +65,25 @@ class NEAT(BaseAlgorithm):
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
crossover_randkeys = jax.random.split(k1, pop_size)
mutate_randkeys = jax.random.split(k2, pop_size)
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))(
crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc
)
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))(
mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys
)
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
@@ -108,8 +104,8 @@ class NEAT(BaseAlgorithm):
)
def member_count(self, state: State):
return state, state.species.member_count
return state.member_count
def generation(self, state: State):
# to analysis the algorithm
return state, state.generation
return state.generation