update to test in servers

This commit is contained in:
wls2002
2023-05-10 22:33:51 +08:00
parent ce35b01896
commit b271a56827
9 changed files with 112 additions and 34 deletions

View File

@@ -7,7 +7,6 @@ import numpy as np
from .species import SpeciesController
from .genome import expand, expand_single
from .function_factory import FunctionFactory
from examples.time_utils import using_cprofile
class Pipeline:
@@ -16,7 +15,9 @@ class Pipeline:
"""
def __init__(self, config, seed=42):
self.time_dict = {}
self.function_factory = FunctionFactory(config, debug=True)
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
@@ -35,6 +36,7 @@ class Pipeline:
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
self.best_fitness = float('-inf')
self.best_genome = None
self.generation_timestamp = time.time()
def ask(self):
@@ -43,7 +45,7 @@ class Pipeline:
:return:
Algorithm gives the population a forward function, then environment gives back the fitnesses.
"""
return self.function_factory.ask(self.pop_nodes, self.pop_connections)
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections)
def tell(self, fitnesses):
@@ -72,10 +74,14 @@ class Pipeline:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
if max(fitnesses) >= self.config.neat.population.fitness_threshold:
print("Fitness limit reached!")
return self.best_genome
self.tell(fitnesses)
print("Generation limit reached!")
return self.best_genome
# @using_cprofile
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
"""
create the next generation
@@ -152,5 +158,10 @@ class Pipeline:
cost_time = new_timestamp - self.generation_timestamp
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx])
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")