use jit().lower.compile in create functions

This commit is contained in:
wls2002
2023-05-08 02:35:04 +08:00
parent 497d89fc69
commit d4a75b9394
9 changed files with 120 additions and 77 deletions

View File

@@ -8,29 +8,27 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
def create_crossover_function(batch: bool):
def create_crossover_function(N, config, batch: bool):
if batch:
return batch_crossover
pop_size = config.neat.population.pop_size
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((pop_size, N, 5))
connections1_lower = jnp.zeros((pop_size, 2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
else:
return crossover
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
@vmap
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
batch_connections2: Array) -> Tuple[Array, Array]:
"""
crossover a batch of genomes
:param randkeys: batches of random keys
:param batch_nodes1:
:param batch_connections1:
:param batch_nodes2:
:param batch_connections2:
:return:
"""
return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
@jit
# @jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
-> Tuple[Array, Array]:
"""
@@ -61,7 +59,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
return new_nodes, new_cons
@partial(jit, static_argnames=['gene_type'])
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
make ar2 align with ar1.
@@ -88,7 +86,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2
@jit
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes