add debug mode for create_xx_functions for detail time cost analysis

This commit is contained in:
wls2002
2023-05-08 15:42:25 +08:00
parent d4a75b9394
commit e201d03157
8 changed files with 70 additions and 38 deletions

View File

@@ -8,7 +8,7 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
def create_crossover_function(N, config, batch: bool):
def create_crossover_function(N, config, batch: bool, debug: bool = False):
if batch:
pop_size = config.neat.population.pop_size
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
@@ -16,16 +16,27 @@ def create_crossover_function(N, config, batch: bool):
connections1_lower = jnp.zeros((pop_size, 2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
res_func = jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
else:
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
res_func = jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
# @jit