create function_factory.py, use to manage functions
This commit is contained in:
@@ -2,13 +2,13 @@ from typing import List, Union, Tuple, Callable
|
||||
import time
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome import expand, expand_single
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
|
||||
create_distance_function, create_crossover_function
|
||||
from .function_factory import FunctionFactory
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -17,17 +17,18 @@ class Pipeline:
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.generation_timestamp = time.time()
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.config = config
|
||||
self.function_factory = FunctionFactory(config)
|
||||
|
||||
self.N = config.basic.init_maximum_nodes
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
|
||||
self.species_controller = SpeciesController(config)
|
||||
self.initialize_func = create_initialize_function(config)
|
||||
self.initialize_func = self.function_factory.create_initialize()
|
||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||
|
||||
self.compile_functions(debug=True)
|
||||
@@ -36,6 +37,7 @@ class Pipeline:
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
def ask(self, batch: bool):
|
||||
"""
|
||||
@@ -140,10 +142,9 @@ class Pipeline:
|
||||
self.compile_functions(debug=True)
|
||||
|
||||
def compile_functions(self, debug=False):
|
||||
self.mutate_func = create_mutate_function(self.N, self.config, batch=True, debug=debug)
|
||||
self.crossover_func = create_crossover_function(self.N, self.config, batch=True, debug=debug)
|
||||
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o', debug=debug)
|
||||
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m', debug=debug)
|
||||
self.mutate_func = self.function_factory.create_mutate(self.N)
|
||||
self.crossover_func = self.function_factory.create_crossover(self.N)
|
||||
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N)
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
|
||||
Reference in New Issue
Block a user