move o2o_distance and o2m_distance to pipelines

This commit is contained in:
wls2002
2023-05-08 01:19:45 +08:00
parent c705b5cfe2
commit 497d89fc69
3 changed files with 55 additions and 23 deletions

View File

@@ -6,9 +6,9 @@ import jax.numpy as jnp
import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import create_crossover_function
from .genome import expand, expand_single
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
create_distance_function, create_crossover_function
class Pipeline:
@@ -30,9 +30,12 @@ class Pipeline:
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True)
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
self.generation, self.o2o_distance, self.o2m_distance)
self.best_fitness = float('-inf')
@@ -57,7 +60,8 @@ class Pipeline:
self.update_next_generation(crossover_pair)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
self.o2o_distance, self.o2m_distance)
self.expand()
@@ -119,8 +123,6 @@ class Pipeline:
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
def expand(self):
"""
Expand the population if needed.