Files
tensorneat-mend/pipeline.py
2023-10-22 21:01:06 +08:00

123 lines
4.0 KiB
Python

"""
pipeline for jitable env like func_fit, gymnax
"""
from functools import partial
from typing import Type
import jax
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import State, Algorithm, Problem
class Pipeline:
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
assert problem_type.jitable, "problem must be jitable"
self.config = config
self.algorithm = algorithm
self.problem = problem_type(config.problem)
print(self.problem.input_shape, self.problem.output_shape)
if isinstance(algorithm, NEAT):
assert config.neat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
elif isinstance(algorithm, HyperNEAT):
assert config.hyperneat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
else:
raise NotImplementedError
self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
self.generation_timestamp = None
def setup(self):
key = jax.random.PRNGKey(self.config.basic.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
state = State()
state = self.algorithm.setup(algorithm_key, state)
return state.update(
evaluate_key=evaluate_key
)
@partial(jax.jit, static_argnums=(0,))
def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key)
keys = jax.random.split(key, self.config.basic.pop_size)
pop = self.algorithm.ask(state)
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
pop_transformed)
state = self.algorithm.tell(state, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses
def auto_run(self, ini_state):
state = ini_state
for _ in range(self.config.basic.generation_limit):
self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state)
state, fitnesses = self.step(state)
fitnesses = jax.device_get(fitnesses)
self.analysis(state, previous_pop, fitnesses)
if max(fitnesses) >= self.config.basic.fitness_target:
print("Fitness limit reached!")
return state, self.best_genome
print("Generation limit reached!")
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[max_idx]
member_count = jax.device_get(state.species_info.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome, *args, **kwargs):
transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
print(f"compile finished, cost time: {time.time() - tic}s")