add debug mode for create_xx_functions for detail time cost analysis

This commit is contained in:
wls2002
2023-05-08 15:42:25 +08:00
parent d4a75b9394
commit e201d03157
8 changed files with 70 additions and 38 deletions

View File

@@ -6,11 +6,12 @@ from numpy.typing import NDArray
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(N, config, type: str):
def create_distance_function(N, config, type: str, debug: bool = False):
"""
:param N:
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:param debug:
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
@@ -20,8 +21,20 @@ def create_distance_function(N, config, type: str):
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
res_func = jit(distance_with_args).lower(nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args) # for debug
else:
return res_func
# return lambda nodes1, connections1, nodes2, connections2: \
# distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m':
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
@@ -30,7 +43,12 @@ def create_distance_function(N, config, type: str):
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
res_func = jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args) # for debug
else:
return res_func
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
@@ -48,6 +66,7 @@ def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray,
:param compatibility_coe:
:return:
"""
def analysis(nodes, connections):
nodes_dict = {}
idx2key = {}