from functools import partial
from typing import Type
import time

import jax
import numpy as np

from algorithm import NEAT, HyperNEAT
from config import Config
from core import State, Algorithm, Problem


class Pipeline:
    """Couple an algorithm with a problem and run it generation by generation.

    Each generation asks the algorithm for a population, transforms every
    member, evaluates the transformed members on the problem, and reports the
    resulting fitnesses back to the algorithm.
    """

    def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
        """Build the problem from ``config.problem`` and prepare the act function.

        Raises:
            AssertionError: if the algorithm's configured ``inputs`` differs
                from the last entry of ``problem.input_shape``.
            NotImplementedError: if ``algorithm`` is neither NEAT nor HyperNEAT.
        """
        self.config = config
        self.algorithm = algorithm
        self.problem = problem_type(config.problem)

        # The configured input count must equal the problem's trailing input dimension.
        if isinstance(algorithm, NEAT):
            assert config.neat.inputs == self.problem.input_shape[-1]
        elif isinstance(algorithm, HyperNEAT):
            assert config.hyperneat.inputs == self.problem.input_shape[-1]
        else:
            raise NotImplementedError

        # Wrap ``act`` in one vmap (over its second argument) for every leading
        # dimension of the problem's input shape.
        act = self.algorithm.act
        leading_dims = len(self.problem.input_shape) - 1
        for _ in range(leading_dims):
            act = jax.vmap(act, in_axes=(None, 0, None))
        self.act_func = act

        # Best individual seen so far by ``analysis``.
        self.best_genome = None
        self.best_fitness = float('-inf')
        # Wall-clock start of the generation in progress; set by ``auto_run``.
        self.generation_timestamp = None

    def setup(self):
        """Create the initial state.

        The seed from ``config.basic.seed`` is split into two keys: one is
        handed to ``algorithm.setup``; the other is stored on the returned
        state as ``evaluate_key``.
        """
        root_key = jax.random.PRNGKey(self.config.basic.seed)
        algorithm_key, evaluate_key = jax.random.split(root_key, 2)
        initial = self.algorithm.setup(algorithm_key, State())
        return initial.update(evaluate_key=evaluate_key)

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state):
        """Run one generation and return ``(new_state, fitnesses)``.

        Jitted with ``self`` as a static argument. ``state.evaluate_key`` is
        split in two: one half is fanned out into ``pop_size`` per-member keys
        for evaluation, the other becomes the new state's ``evaluate_key``.
        """
        eval_key, next_key = jax.random.split(state.evaluate_key)
        member_keys = jax.random.split(eval_key, self.config.basic.pop_size)

        population = self.algorithm.ask(state)

        # Map over population members only; ``state`` is shared.
        batched_transform = jax.vmap(self.algorithm.transform, in_axes=(None, 0))
        networks = batched_transform(state, population)

        # Map over (key, member) pairs; ``state`` and the act function are shared.
        batched_evaluate = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))
        fitnesses = batched_evaluate(member_keys, state, self.act_func, networks)

        told = self.algorithm.tell(state, fitnesses)
        return told.update(evaluate_key=next_key), fitnesses

    def auto_run(self, ini_state):
        """Step from ``ini_state`` until a stop condition is met.

        Stops when the best fitness of a generation reaches
        ``config.basic.fitness_target`` or after
        ``config.basic.generation_limit`` generations, printing which one
        happened.

        Returns:
            ``(state, self.best_genome)``.
        """
        state = ini_state
        for _ in range(self.config.basic.generation_limit):
            self.generation_timestamp = time.time()

            # Taken before stepping so ``analysis`` can index the population
            # that produced ``fitnesses``.
            previous_pop = self.algorithm.ask(state)

            state, device_fitnesses = self.step(state)
            fitnesses = jax.device_get(device_fitnesses)

            self.analysis(state, previous_pop, fitnesses)

            if max(fitnesses) >= self.config.basic.fitness_target:
                print("Fitness limit reached!")
                break
        else:
            # Loop ran to completion without reaching the target.
            print("Generation limit reached!")

        return state, self.best_genome

    def analysis(self, state, pop, fitnesses):
        """Print per-generation statistics and track the best genome.

        Reads ``self.generation_timestamp`` to report elapsed time, and updates
        ``self.best_fitness`` / ``self.best_genome`` when this generation's
        top fitness is strictly greater than the stored best.
        """
        max_f, min_f = max(fitnesses), min(fitnesses)
        mean_f, std_f = np.mean(fitnesses), np.std(fitnesses)

        # Measured after the statistics above, so their cost is included.
        cost_time = time.time() - self.generation_timestamp

        champion_idx = np.argmax(fitnesses)
        if fitnesses[champion_idx] > self.best_fitness:
            self.best_fitness = fitnesses[champion_idx]
            self.best_genome = pop[champion_idx]

        # Only entries with a positive member count are reported.
        member_count = jax.device_get(state.species_info.member_count)
        species_sizes = [int(count) for count in member_count if count > 0]

        print(f"Generation: {state.generation}",
              f"species: {len(species_sizes)}, {species_sizes}",
              f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")

    def show(self, state, genome):
        """Transform ``genome`` and pass it to the problem's ``show``."""
        network = self.algorithm.transform(state, genome)
        self.problem.show(state.evaluate_key, state, self.act_func, network)

    def pre_compile(self, state):
        """Lower and compile ``step`` for ``state``, printing the time taken."""
        tic = time.time()
        print("start compile")
        # ``self`` is passed explicitly because ``lower`` belongs to the
        # underlying jitted function. The compiled result is not kept.
        self.step.lower(self, state).compile()
        print(f"compile finished, cost time: {time.time() - tic}s")