41 lines
1.5 KiB
Python
41 lines
1.5 KiB
Python
import jax
|
|
from jax import numpy as jnp, vmap
|
|
|
|
from config import NeatConfig
|
|
from core import Genome, State, Gene
|
|
from .mutate import mutate
|
|
from .crossover import crossover
|
|
|
|
|
|
def create_next_generation(config: NeatConfig, gene: Gene, state: State, randkey, winner, loser, elite_mask):
    """Build the next population by batched crossover, mutation and elitism.

    For every population slot ``i`` the genomes at ``winner[i]`` and
    ``loser[i]`` are crossed over (vectorised with ``vmap``), the child is
    then mutated, and slots flagged in ``elite_mask`` keep the un-mutated
    crossover child instead of the mutated one.

    Args:
        config: NEAT configuration; forwarded unchanged to ``mutate``.
        gene: gene definition; forwarded unchanged to ``mutate``.
        state: current state. Only the leading size of ``idx2species``,
            plus ``pop_genomes`` and ``next_node_key``, are read here.
        randkey: JAX PRNG key, split into per-slot crossover and mutation keys.
        winner: row indices into the current population, one per new slot
            (presumably an integer array of length pop_size -- confirm).
        loser: same shape/meaning as ``winner``, for the second parent.
        elite_mask: per-slot boolean mask; True means the slot keeps the
            crossover child and the mutation result is discarded.

    Returns:
        Whatever ``state.update`` returns when given the new ``pop_genomes``
        and a ``next_node_key`` one past the largest non-NaN node key found
        in column 0 of the new population's node array.
    """
    population = state.pop_genomes
    pop_size = state.idx2species.shape[0]

    # Slot i is offered candidate node key next_node_key + i for mutate.
    candidate_node_keys = jnp.arange(pop_size) + state.next_node_key

    # One key for the crossover stream, one for the mutation stream,
    # each then fanned out to one key per population slot.
    crossover_key, mutation_key = jax.random.split(randkey, 2)
    per_slot_crossover_keys = jax.random.split(crossover_key, pop_size)
    per_slot_mutation_keys = jax.random.split(mutation_key, pop_size)

    # Gather both parents for every slot, then cross them over in one batch.
    first_parents = Genome(population.nodes[winner], population.conns[winner])
    second_parents = Genome(population.nodes[loser], population.conns[loser])
    children = vmap(crossover)(per_slot_crossover_keys, first_parents, second_parents)

    # config/gene/state are shared across the batch; keys, genomes and
    # candidate node keys are mapped along their leading axis.
    batched_mutate = vmap(mutate, in_axes=(None, None, None, 0, 0, 0))
    mutated_children = batched_mutate(
        config, gene, state, per_slot_mutation_keys, children, candidate_node_keys
    )

    # Elites skip mutation: broadcast the per-slot mask over the two
    # trailing axes of the node/connection arrays.
    keep_unmutated = elite_mask[:, None, None]
    new_nodes = jnp.where(keep_unmutated, children.nodes, mutated_children.nodes)
    new_conns = jnp.where(keep_unmutated, children.conns, mutated_children.conns)

    # Column 0 holds node keys; NaN marks an unused slot, so push those to
    # -inf before taking the max so they can never be selected.
    node_keys = new_nodes[:, :, 0]
    largest_node_key = jnp.where(jnp.isnan(node_keys), -jnp.inf, node_keys).max()

    return state.update(
        pop_genomes=Genome(new_nodes, new_conns),
        next_node_key=largest_node_key + 1,
    )
|