use jit().lower.compile in create functions

This commit is contained in:
wls2002
2023-05-08 02:35:04 +08:00
parent 497d89fc69
commit d4a75b9394
9 changed files with 120 additions and 77 deletions

View File

@@ -6,25 +6,31 @@ from numpy.typing import NDArray
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(config, type: str):
def create_distance_function(N, config, type: str):
"""
:param N:
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
def distance_with_args(nodes1, connections1, nodes2, connections2):
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
# return lambda nodes1, connections1, nodes2, connections2: \
# distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m':
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
pop_size = config.neat.population.pop_size
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')