use jit().lower.compile in create functions
This commit is contained in:
@@ -6,25 +6,31 @@ from numpy.typing import NDArray
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
def create_distance_function(config, type: str):
|
||||
def create_distance_function(N, config, type: str):
|
||||
"""
|
||||
:param N:
|
||||
:param config:
|
||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||
:return:
|
||||
"""
|
||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
|
||||
def distance_with_args(nodes1, connections1, nodes2, connections2):
|
||||
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
if type == 'o2o':
|
||||
return lambda nodes1, connections1, nodes2, connections2: \
|
||||
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
# return lambda nodes1, connections1, nodes2, connections2: \
|
||||
# distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
elif type == 'o2m':
|
||||
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
|
||||
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
|
||||
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
|
||||
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
pop_size = config.neat.population.pop_size
|
||||
nodes1_lower = jnp.zeros((N, 5))
|
||||
connections1_lower = jnp.zeros((2, N, N))
|
||||
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
|
||||
else:
|
||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user