finish all refactoring

This commit is contained in:
wls2002
2024-02-21 15:41:08 +08:00
parent aac41a089d
commit 6970e6a6d5
44 changed files with 856 additions and 825 deletions

View File

@@ -1,25 +1,23 @@
from functools import partial
from typing import Type
import jax
import jax, jax.numpy as jnp
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import State, Algorithm, Problem
from algorithm import BaseAlgorithm
from problem import BaseProblem
from utils import State
class Pipeline:
def __init__(
self,
algorithm: Algorithm,
problem: Problem,
seed: int = 42,
fitness_target: float = 1,
generation_limit: int = 1000,
pop_size: int = 100,
self,
algorithm: BaseAlgorithm,
problem: BaseProblem,
seed: int = 42,
fitness_target: float = 1,
generation_limit: int = 1000,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -28,17 +26,18 @@ class Pipeline:
self.seed = seed
self.fitness_target = fitness_target
self.generation_limit = generation_limit
self.pop_size = pop_size
self.pop_size = self.algorithm.pop_size
print(self.problem.input_shape, self.problem.output_shape)
# TODO: make each algorithm's input_num and output_num
assert algorithm.input_num == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
assert algorithm.num_inputs == self.problem.input_shape[-1], \
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
self.act_func = self.algorithm.act
# self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
# for _ in range(len(self.problem.input_shape) - 1):
# self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
@@ -46,41 +45,57 @@ class Pipeline:
def setup(self):
key = jax.random.PRNGKey(self.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
key, algorithm_key, evaluate_key = jax.random.split(key, 3)
# TODO: Problem should has setup function to maintain state
return State(
randkey=key,
alg=self.algorithm.setup(algorithm_key),
pro=self.problem.setup(evaluate_key),
)
@partial(jax.jit, static_argnums=(0,))
def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key)
key, sub_key = jax.random.split(state.randkey)
keys = jax.random.split(key, self.pop_size)
pop = self.algorithm.ask(state)
pop = self.algorithm.ask(state.alg)
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
pop_transformed = jax.vmap(self.algorithm.transform)(pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
pop_transformed)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(
keys,
state.pro,
self.algorithm.forward,
pop_transformed
)
state = self.algorithm.tell(state, fitnesses)
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses
alg_state = self.algorithm.tell(state.alg, fitnesses)
return state.update(
randkey=sub_key,
alg=alg_state,
), fitnesses
def auto_run(self, ini_state):
state = ini_state
compiled_step = jax.jit(self.step).lower(ini_state).compile()
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state)
previous_pop = self.algorithm.ask(state.alg)
state, fitnesses = self.step(state)
state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses)
@@ -102,22 +117,15 @@ class Pipeline:
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
member_count = jax.device_get(state.species_info.member_count)
member_count = jax.device_get(self.algorithm.member_count(state.alg))
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {state.generation}",
print(f"Generation: {self.algorithm.generation(state.alg)}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome, *args, **kwargs):
transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
print(f"compile finished, cost time: {time.time() - tic}s")
def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(best)
self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs)