finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
This commit is contained in:
@@ -1,14 +1,17 @@
|
||||
from functools import partial
|
||||
"""
|
||||
Crossover two genomes to generate a new genome.
|
||||
The calculation method is the same as the crossover operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import jit, vmap, Array
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
|
||||
-> Tuple[Array, Array]:
|
||||
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
@@ -23,7 +26,11 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
# make homologous genes align in nodes2 align with nodes1
|
||||
nodes2 = align_array(keys1, keys2, nodes2, 'node')
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
@@ -34,7 +41,6 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
@@ -62,7 +68,6 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
# @jit
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
|
||||
Reference in New Issue
Block a user