add method 'create_crossover_function' and 'create_distance_function'

This commit is contained in:
wls2002
2023-05-07 22:16:27 +08:00
parent cec40b254f
commit 47bb593a53
7 changed files with 45 additions and 15 deletions

View File

@@ -5,7 +5,7 @@ import jax
import numpy as np
from numpy.typing import NDArray
from .genome import distance
from .genome import create_distance_function
class Species(object):
@@ -47,8 +47,8 @@ class SpeciesController:
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.distance = distance
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
@@ -106,7 +106,7 @@ class SpeciesController:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)