from functools import partial
import time

import jax
import jax.numpy as jnp
import numpy as np

from algorithm import BaseAlgorithm
from problem import BaseProblem
from utils import State


class Pipeline:
    """Drive an evolutionary ``algorithm`` against a jittable ``problem``.

    Every generation the pipeline asks the algorithm for a population,
    evaluates it on the problem (batched with ``jax.vmap``, compiled with
    ``jax.jit`` in ``auto_run``), tells the algorithm the fitnesses, and
    records the best genome seen so far on the host.
    """

    def __init__(
        self,
        algorithm: BaseAlgorithm,
        problem: BaseProblem,
        seed: int = 42,
        fitness_target: float = 1,
        generation_limit: int = 1000,
    ):
        """Store the configuration and check algorithm/problem compatibility.

        Args:
            algorithm: Object providing ``setup``/``ask``/``tell``/``transform``/
                ``forward`` plus the ``pop_size`` and ``num_inputs`` attributes.
            problem: Must have ``jitable`` set; ``problem.input_shape[-1]`` must
                equal ``algorithm.num_inputs``.
            seed: Seed for the root ``jax.random.PRNGKey`` used by ``setup``.
            fitness_target: ``auto_run`` stops once any fitness reaches it.
            generation_limit: Maximum number of generations ``auto_run`` runs.

        Raises:
            AssertionError: If the problem is not jittable or the input sizes
                disagree (these checks use ``assert`` and are skipped under
                ``python -O``).
        """
        assert problem.jitable, "Currently, problem must be jitable"

        self.algorithm = algorithm
        self.problem = problem
        self.seed = seed
        self.fitness_target = fitness_target
        self.generation_limit = generation_limit
        self.pop_size = self.algorithm.pop_size

        print(self.problem.input_shape, self.problem.output_shape)

        # TODO: make each algorithm's input_num and output_num
        assert algorithm.num_inputs == self.problem.input_shape[-1], (
            f"algorithm input shape is {algorithm.num_inputs} "
            f"but problem input shape is {self.problem.input_shape}"
        )

        # Best result observed across all generations, tracked on the host.
        self.best_genome = None
        self.best_fitness = float('-inf')
        # Wall-clock start of the current generation; set by ``auto_run``.
        self.generation_timestamp = None

    def setup(self):
        """Build the initial ``State`` from ``self.seed``.

        The root key is split three ways: one part becomes the state's
        ``randkey``, the other two seed ``algorithm.setup`` and
        ``problem.setup``.
        """
        key = jax.random.PRNGKey(self.seed)
        key, algorithm_key, evaluate_key = jax.random.split(key, 3)

        # TODO: Problem should has setup function to maintain state
        return State(
            randkey=key,
            alg=self.algorithm.setup(algorithm_key),
            pro=self.problem.setup(evaluate_key),
        )

    def step(self, state):
        """Run one generation: ask, transform, evaluate, sanitise, tell.

        Written to be traced by ``jax.jit`` (see ``auto_run``).

        Returns:
            ``(new_state, fitnesses)``. ``fitnesses`` has one entry per
            individual, with NaN values replaced by ``-1e6``.
        """
        key, sub_key = jax.random.split(state.randkey)
        # One evaluation key per individual; ``sub_key`` carries forward.
        keys = jax.random.split(key, self.pop_size)

        pop = self.algorithm.ask(state.alg)
        pop_transformed = jax.vmap(self.algorithm.transform)(pop)

        # Batch over (key, genome); problem state and forward fn are shared.
        fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(
            keys, state.pro, self.algorithm.forward, pop_transformed
        )

        # Replace NaN fitnesses with a large penalty value.
        fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)

        alg_state = self.algorithm.tell(state.alg, fitnesses)

        return state.update(
            randkey=sub_key,
            alg=alg_state,
        ), fitnesses

    def auto_run(self, ini_state):
        """Evolve until ``fitness_target`` or ``generation_limit`` is reached.

        Args:
            ini_state: Initial state, typically the result of ``setup()``.

        Returns:
            ``(final_state, best_genome)``.

        Raises:
            AssertionError: If a NaN fitness reaches the host (defensive only,
                since ``step`` already replaces NaN values).
        """
        state = ini_state
        # Compile once against the initial state's structure.
        compiled_step = jax.jit(self.step).lower(ini_state).compile()

        for generation_idx in range(self.generation_limit):
            self.generation_timestamp = time.time()

            # Keep the population that this step evaluates, for reporting.
            previous_pop = self.algorithm.ask(state.alg)

            state, fitnesses = compiled_step(state)

            # Blocks until the device result is available on the host.
            fitnesses = jax.device_get(fitnesses)

            # Vectorised NaN guard; raise explicitly so it survives ``-O``.
            nan_indices = np.flatnonzero(np.isnan(fitnesses))
            if nan_indices.size > 0:
                idx = nan_indices[0]
                print("Fitness is nan")
                print(previous_pop[0][idx], previous_pop[1][idx])
                raise AssertionError("Fitness is nan")

            self.analysis(state, previous_pop, fitnesses)

            if np.max(fitnesses) >= self.fitness_target:
                print("Fitness limit reached!")
                return state, self.best_genome

            # Debug output every 5 generations: count non-NaN rows (column 0)
            # of the first individual's ``pop[0]`` / ``pop[1]`` arrays.
            # Computed only when printed to avoid per-generation host work.
            if generation_idx % 5 == 0:
                node = previous_pop[0][0][:, 0]
                conn = previous_pop[1][0][:, 0]
                print("node_count", jnp.sum(~jnp.isnan(node)))
                print("conn_count", jnp.sum(~jnp.isnan(conn)))

        print("Generation limit reached!")
        return state, self.best_genome

    def analysis(self, state, pop, fitnesses):
        """Update the best-genome record and print generation statistics.

        Args:
            state: Current pipeline state (queried for species/generation info).
            pop: The population that produced ``fitnesses``; ``pop[0][i]`` and
                ``pop[1][i]`` together form individual ``i``'s genome.
            fitnesses: Host array with one fitness per individual.
        """
        max_f, min_f = np.max(fitnesses), np.min(fitnesses)
        mean_f, std_f = np.mean(fitnesses), np.std(fitnesses)

        cost_time = time.time() - self.generation_timestamp

        max_idx = np.argmax(fitnesses)
        if fitnesses[max_idx] > self.best_fitness:
            self.best_fitness = fitnesses[max_idx]
            self.best_genome = pop[0][max_idx], pop[1][max_idx]

        member_count = jax.device_get(self.algorithm.member_count(state.alg))
        # Only non-empty species are reported.
        species_sizes = [int(i) for i in member_count if i > 0]

        print(f"Generation: {self.algorithm.generation(state.alg)}",
              f"species: {len(species_sizes)}, {species_sizes}",
              f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")

    def show(self, state, best, *args, **kwargs):
        """Transform ``best`` and hand it to ``problem.show`` for display.

        Extra positional/keyword arguments are forwarded to ``problem.show``.
        """
        transformed = self.algorithm.transform(best)
        self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs)