complete normal neat algorithm
This commit is contained in:
@@ -4,12 +4,12 @@ import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
|
||||
def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array):
|
||||
def crossover(randkey, nodes1: Array, conns1: Array, nodes2: Array, conns2: Array):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey_1, randkey_2, key= jax.random.split(state.randkey, 3)
|
||||
randkey_1, randkey_2, key= jax.random.split(randkey, 3)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
@@ -21,11 +21,11 @@ def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array):
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, cons2, True)
|
||||
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, conns2, True)
|
||||
new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(cons2), conns1, crossover_gene(randkey_2, conns1, cons2))
|
||||
|
||||
return state.update(randkey=key), new_nodes, new_cons
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array:
|
||||
|
||||
Reference in New Issue
Block a user