refactor genome.py use (C, 4) to replace (2, N, N) to represent connections
faster, faster and faster!
This commit is contained in:
@@ -7,16 +7,16 @@ from jax import numpy as jnp
|
||||
|
||||
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
|
||||
-> Tuple[Array, Array]:
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
:param randkey:
|
||||
:param nodes1:
|
||||
:param connections1:
|
||||
:param cons1:
|
||||
:param nodes2:
|
||||
:param connections2:
|
||||
:param cons2:
|
||||
:return:
|
||||
"""
|
||||
randkey_1, randkey_2 = jax.random.split(randkey)
|
||||
@@ -27,15 +27,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
|
||||
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
new_cons = unflatten_connections(len(keys1), new_cons)
|
||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['gene_type'])
|
||||
# @partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
@@ -63,7 +62,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
|
||||
Reference in New Issue
Block a user