move o2o_distance and o2m_distance to pipelines
This commit is contained in:
@@ -6,9 +6,9 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome import create_crossover_function
|
||||
from .genome import expand, expand_single
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
|
||||
create_distance_function, create_crossover_function
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -30,9 +30,12 @@ class Pipeline:
|
||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
|
||||
self.crossover_func = create_crossover_function(batch=True)
|
||||
self.o2o_distance = create_distance_function(self.config, type='o2o')
|
||||
self.o2m_distance = create_distance_function(self.config, type='o2m')
|
||||
|
||||
self.generation = 0
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
|
||||
self.generation, self.o2o_distance, self.o2m_distance)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
|
||||
@@ -57,7 +60,8 @@ class Pipeline:
|
||||
|
||||
self.update_next_generation(crossover_pair)
|
||||
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
|
||||
self.o2o_distance, self.o2m_distance)
|
||||
|
||||
self.expand()
|
||||
|
||||
@@ -119,8 +123,6 @@ class Pipeline:
|
||||
# (pop_size, ) to (pop_size, 1, 1, 1)
|
||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||
|
||||
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
Expand the population if needed.
|
||||
|
||||
Reference in New Issue
Block a user