remove create_func....

This commit is contained in:
wls2002
2023-08-04 17:29:36 +08:00
parent c7fb1ddabe
commit 0e44b13291
29 changed files with 591 additions and 259 deletions

View File

@@ -1,83 +1,115 @@
import time
from typing import Union, Callable
from functools import partial
from typing import Type
import jax
from jax import vmap, jit
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import Algorithm, Genome
from core import State, Algorithm, Problem
class Pipeline:
"""
Simple pipeline.
"""
def __init__(self, config: Config, algorithm: Algorithm):
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
self.config = config
self.algorithm = algorithm
self.problem = problem_type(config.problem)
randkey = jax.random.PRNGKey(config.basic.seed)
self.state = algorithm.setup(randkey)
if isinstance(algorithm, NEAT):
assert config.neat.inputs == self.problem.input_shape[-1]
elif isinstance(algorithm, HyperNEAT):
assert config.hyperneat.inputs == self.problem.input_shape[-1]
else:
raise NotImplementedError
self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
self.generation_timestamp = time.time()
self.generation_timestamp = None
self.evaluate_time = 0
def setup(self):
key = jax.random.PRNGKey(self.config.basic.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
state = State()
state = self.algorithm.setup(algorithm_key, state)
return state.update(
evaluate_key=evaluate_key
)
self.act_func = jit(self.algorithm.act)
self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None)))
self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
self.tell_func = jit(self.algorithm.tell)
@partial(jax.jit, static_argnums=(0,))
def step(self, state):
def ask(self):
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms)
key, sub_key = jax.random.split(state.evaluate_key)
keys = jax.random.split(key, self.config.basic.pop_size)
def tell(self, fitness):
# self.state = self.tell_func(self.state, fitness)
new_state = self.tell_func(self.state, fitness)
self.state = new_state
pop = self.algorithm.ask(state)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
pop_transformed)
state = self.algorithm.tell(state, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses
def auto_run(self, ini_state):
state = ini_state
for _ in range(self.config.basic.generation_limit):
forward_func = self.ask()
fitnesses = fitness_func(forward_func)
self.generation_timestamp = time.time()
if analysis is not None:
if analysis == "default":
self.default_analysis(fitnesses)
else:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
previous_pop = self.algorithm.ask(state)
state, fitnesses = self.step(state)
fitnesses = jax.device_get(fitnesses)
self.analysis(state, previous_pop, fitnesses)
if max(fitnesses) >= self.config.basic.fitness_target:
print("Fitness limit reached!")
return self.best_genome
return state, self.best_genome
self.tell(fitnesses)
print("Generation limit reached!")
return self.best_genome
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx])
self.best_genome = pop[max_idx]
member_count = jax.device_get(self.state.species_info.member_count)
member_count = jax.device_get(state.species_info.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.state.generation}",
print(f"Generation: {state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome):
transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
# compiled_step = jax.jit(self.step, static_argnums=(0,)).lower(state).compile()
# self.__dict__['step'] = compiled_step
print(f"compile finished, cost time: {time.time() - tic}s")