clean imports and delete "create_XXX_functions"
This commit is contained in:
@@ -8,38 +8,7 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
def create_crossover_function(N, config, batch: bool, debug: bool = False):
|
||||
if batch:
|
||||
pop_size = config.neat.population.pop_size
|
||||
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
||||
nodes1_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections1_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
|
||||
res_func = jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
if debug:
|
||||
return lambda *args: res_func(*args)
|
||||
else:
|
||||
return res_func
|
||||
|
||||
else:
|
||||
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
|
||||
nodes1_lower = jnp.zeros((N, 5))
|
||||
connections1_lower = jnp.zeros((2, N, N))
|
||||
nodes2_lower = jnp.zeros((N, 5))
|
||||
connections2_lower = jnp.zeros((2, N, N))
|
||||
|
||||
res_func = jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
if debug:
|
||||
return lambda *args: res_func(*args)
|
||||
else:
|
||||
return res_func
|
||||
|
||||
|
||||
# @jit
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||
-> Tuple[Array, Array]:
|
||||
"""
|
||||
@@ -70,7 +39,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['gene_type'])
|
||||
@partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
make ar2 align with ar1.
|
||||
@@ -97,7 +66,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
# @jit
|
||||
@jit
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
|
||||
Reference in New Issue
Block a user