refactor genome.py use (C, 4) to replace (2, N, N) to represent connections
use "cons" to replace "connections" in code for beauty
This commit is contained in:
@@ -5,8 +5,6 @@ import jax
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||
@@ -29,11 +27,9 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
|
||||
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
new_cons = unflatten_connections(len(keys1), new_cons)
|
||||
|
||||
return new_nodes, new_cons
|
||||
@@ -42,6 +38,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
@partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
|
||||
Reference in New Issue
Block a user