add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -4,21 +4,41 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
def create_distance_function(config, type: str):
|
||||
"""
|
||||
:param config:
|
||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||
:return:
|
||||
"""
|
||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
if type == 'o2o':
|
||||
return lambda nodes1, connections1, nodes2, connections2: \
|
||||
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
elif type == 'o2m':
|
||||
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
|
||||
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
|
||||
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
|
||||
else:
|
||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||
|
||||
|
||||
@jit
|
||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
|
||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
||||
compatibility_coe: float = 0.5) -> Array:
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
|
||||
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
|
||||
"""
|
||||
|
||||
nd = node_distance(nodes1, nodes2) # node distance
|
||||
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
||||
|
||||
# refactor connections
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
cd = connection_distance(cons1, cons2) # connection distance
|
||||
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
||||
return nd + cd
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user