refactor genome.py use (C, 4) to replace (2, N, N) to represent connections

faster, faster and faster!
This commit is contained in:
wls2002
2023-05-12 00:57:55 +08:00
parent e5fc1167d9
commit 47b1a1dbb2
16 changed files with 363 additions and 419 deletions

View File

@@ -7,16 +7,16 @@ from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
-> Tuple[Array, Array]:
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
:param randkey:
:param nodes1:
:param connections1:
:param cons1:
:param nodes2:
:param connections2:
:param cons2:
:return:
"""
randkey_1, randkey_2 = jax.random.split(randkey)
@@ -27,15 +27,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons)
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
return new_nodes, new_cons
@partial(jit, static_argnames=['gene_type'])
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
@@ -63,7 +62,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2
@jit
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes