add debug mode for create_xx_functions for detail time cost analysis
This commit is contained in:
@@ -6,11 +6,12 @@ from numpy.typing import NDArray
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
def create_distance_function(N, config, type: str):
|
||||
def create_distance_function(N, config, type: str, debug: bool = False):
|
||||
"""
|
||||
:param N:
|
||||
:param config:
|
||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||
:param debug:
|
||||
:return:
|
||||
"""
|
||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
@@ -20,8 +21,20 @@ def create_distance_function(N, config, type: str):
|
||||
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
if type == 'o2o':
|
||||
return lambda nodes1, connections1, nodes2, connections2: \
|
||||
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
nodes1_lower = jnp.zeros((N, 5))
|
||||
connections1_lower = jnp.zeros((2, N, N))
|
||||
nodes2_lower = jnp.zeros((N, 5))
|
||||
connections2_lower = jnp.zeros((2, N, N))
|
||||
|
||||
res_func = jit(distance_with_args).lower(nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
if debug:
|
||||
return lambda *args: res_func(*args) # for debug
|
||||
else:
|
||||
return res_func
|
||||
|
||||
# return lambda nodes1, connections1, nodes2, connections2: \
|
||||
# distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
elif type == 'o2m':
|
||||
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
@@ -30,7 +43,12 @@ def create_distance_function(N, config, type: str):
|
||||
connections1_lower = jnp.zeros((2, N, N))
|
||||
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
|
||||
res_func = jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
|
||||
if debug:
|
||||
return lambda *args: res_func(*args) # for debug
|
||||
else:
|
||||
return res_func
|
||||
|
||||
else:
|
||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||
|
||||
@@ -48,6 +66,7 @@ def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray,
|
||||
:param compatibility_coe:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def analysis(nodes, connections):
|
||||
nodes_dict = {}
|
||||
idx2key = {}
|
||||
|
||||
Reference in New Issue
Block a user