complete normal neat algorithm

This commit is contained in:
wls2002
2023-07-18 23:55:36 +08:00
parent 40cf0b6fbe
commit 0a2a9fd1be
26 changed files with 880 additions and 251 deletions

View File

@@ -4,12 +4,12 @@ import jax
from jax import jit, Array, numpy as jnp
def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array):
def crossover(randkey, nodes1: Array, conns1: Array, nodes2: Array, conns2: Array):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey_1, randkey_2, key= jax.random.split(state.randkey, 3)
randkey_1, randkey_2, key= jax.random.split(randkey, 3)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
@@ -21,11 +21,11 @@ def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array):
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, True)
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
cons2 = align_array(con_keys1, con_keys2, conns2, True)
new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(cons2), conns1, crossover_gene(randkey_2, conns1, cons2))
return state.update(randkey=key), new_nodes, new_cons
return new_nodes, new_cons
def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array: