add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -5,7 +5,7 @@ import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .genome import distance
|
||||
from .genome import create_distance_function
|
||||
|
||||
|
||||
class Species(object):
|
||||
@@ -47,8 +47,8 @@ class SpeciesController:
|
||||
self.species_idxer = count(0)
|
||||
self.species: Dict[int, Species] = {} # species_id -> species
|
||||
|
||||
self.distance = distance
|
||||
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
|
||||
self.o2o_distance = create_distance_function(self.config, type='o2o')
|
||||
self.o2m_distance = create_distance_function(self.config, type='o2m')
|
||||
|
||||
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
|
||||
"""
|
||||
@@ -106,7 +106,7 @@ class SpeciesController:
|
||||
# the representatives of new species
|
||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||
distances = [
|
||||
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
for r in rid
|
||||
]
|
||||
distances = np.array(distances)
|
||||
|
||||
Reference in New Issue
Block a user