From 3b6fe7eadc3c5c9cb3d93b418243765665700e36 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 9 Aug 2023 18:01:21 +0800 Subject: [PATCH] add more rl task in examples --- algorithm/hyperneat/hyperneat.py | 1 - core/problem.py | 2 + examples/func_fit/xor.py | 2 +- examples/func_fit/xor_hyperneat.py | 2 +- examples/func_fit/xor_recurrent.py | 2 +- examples/gymnax/acrobot.py | 39 +++++++ examples/gymnax/cartpole.py | 2 +- examples/gymnax/mountain_car.py | 39 +++++++ examples/gymnax/mountain_car_continuous.py | 38 +++++++ examples/gymnax/pendulum.py | 39 +++++++ examples/gymnax/reacher.py | 36 +++++++ pipeline_jitable_env.py | 120 +++++++++++++++++++++ pipeline.py => pipeline_time.py | 4 - problem/func_fit/func_fit.py | 2 + problem/rl_env/gym_env.py | 40 +++++++ problem/rl_env/gymnax_env.py | 4 +- problem/rl_env/{rl_env.py => rl_jit.py} | 2 + problem/rl_env/rl_unjit.py | 69 ++++++++++++ 18 files changed, 431 insertions(+), 12 deletions(-) create mode 100644 examples/gymnax/acrobot.py create mode 100644 examples/gymnax/mountain_car.py create mode 100644 examples/gymnax/mountain_car_continuous.py create mode 100644 examples/gymnax/pendulum.py create mode 100644 examples/gymnax/reacher.py create mode 100644 pipeline_jitable_env.py rename pipeline.py => pipeline_time.py (94%) create mode 100644 problem/rl_env/gym_env.py rename problem/rl_env/{rl_env.py => rl_jit.py} (99%) create mode 100644 problem/rl_env/rl_unjit.py diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py index 902c8c0..9bd21ae 100644 --- a/algorithm/hyperneat/hyperneat.py +++ b/algorithm/hyperneat/hyperneat.py @@ -6,7 +6,6 @@ import numpy as np from config import Config, HyperNeatConfig from core import Algorithm, Substrate, State, Genome, Gene -from utils import Act, Agg from .substrate import analysis_substrate from algorithm import NEAT diff --git a/core/problem.py b/core/problem.py index c70c409..87d4396 100644 --- a/core/problem.py +++ b/core/problem.py @@ -6,6 +6,8 @@ from 
.state import State class Problem: + jitable: bool + def __init__(self, problem_config: ProblemConfig = ProblemConfig()): self.config = problem_config diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py index a2d45ee..8e5ca6c 100644 --- a/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -1,5 +1,5 @@ from config import * -from pipeline import Pipeline +from pipeline_jitable_env import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.func_fit import XOR, FuncFitConfig diff --git a/examples/func_fit/xor_hyperneat.py b/examples/func_fit/xor_hyperneat.py index cfd23f1..0148e28 100644 --- a/examples/func_fit/xor_hyperneat.py +++ b/examples/func_fit/xor_hyperneat.py @@ -1,5 +1,5 @@ from config import * -from pipeline import Pipeline +from pipeline_jitable_env import Pipeline from algorithm.neat import NormalGene, NormalGeneConfig from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig from problem.func_fit import XOR3d, FuncFitConfig diff --git a/examples/func_fit/xor_recurrent.py b/examples/func_fit/xor_recurrent.py index d100fd8..6787a7d 100644 --- a/examples/func_fit/xor_recurrent.py +++ b/examples/func_fit/xor_recurrent.py @@ -1,5 +1,5 @@ from config import * -from pipeline import Pipeline +from pipeline_jitable_env import Pipeline from algorithm import NEAT from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig from problem.func_fit import XOR3d, FuncFitConfig diff --git a/examples/gymnax/acrobot.py b/examples/gymnax/acrobot.py new file mode 100644 index 0000000..0f6cdd0 --- /dev/null +++ b/examples/gymnax/acrobot.py @@ -0,0 +1,39 @@ +import jax.numpy as jnp + +from config import * +from pipeline_jitable_env import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import GymNaxConfig, GymNaxEnv + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + 
fitness_target=0, + pop_size=10000 + ), + neat=NeatConfig( + inputs=6, + outputs=3, + ), + gene=NormalGeneConfig( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + problem=GymNaxConfig( + env_name='Acrobot-v1', + output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2} + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/cartpole.py b/examples/gymnax/cartpole.py index bc18d86..d5b3c84 100644 --- a/examples/gymnax/cartpole.py +++ b/examples/gymnax/cartpole.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from config import * -from pipeline import Pipeline +from pipeline_jitable_env import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.rl_env import GymNaxConfig, GymNaxEnv diff --git a/examples/gymnax/mountain_car.py b/examples/gymnax/mountain_car.py new file mode 100644 index 0000000..7a897bb --- /dev/null +++ b/examples/gymnax/mountain_car.py @@ -0,0 +1,39 @@ +import jax.numpy as jnp + +from config import * +from pipeline_jitable_env import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import GymNaxConfig, GymNaxEnv + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=0, + pop_size=10000 + ), + neat=NeatConfig( + inputs=2, + outputs=3, + ), + gene=NormalGeneConfig( + activation_default=Act.sigmoid, + activation_options=(Act.sigmoid,), + ), + problem=GymNaxConfig( + env_name='MountainCar-v0', + output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2} + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + 
state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/mountain_car_continuous.py b/examples/gymnax/mountain_car_continuous.py new file mode 100644 index 0000000..d5b5a01 --- /dev/null +++ b/examples/gymnax/mountain_car_continuous.py @@ -0,0 +1,38 @@ +import jax.numpy as jnp + +from config import * +from pipeline_jitable_env import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import GymNaxConfig, GymNaxEnv + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=100, + pop_size=10000 + ), + neat=NeatConfig( + inputs=2, + outputs=1, + ), + gene=NormalGeneConfig( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + problem=GymNaxConfig( + env_name='MountainCarContinuous-v0' + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/pendulum.py b/examples/gymnax/pendulum.py new file mode 100644 index 0000000..c991f16 --- /dev/null +++ b/examples/gymnax/pendulum.py @@ -0,0 +1,39 @@ +import jax.numpy as jnp + +from config import * +from pipeline_jitable_env import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import GymNaxConfig, GymNaxEnv + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=0, + pop_size=10000 + ), + neat=NeatConfig( + inputs=3, + outputs=1, + ), + gene=NormalGeneConfig( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + problem=GymNaxConfig( + env_name='Pendulum-v1', + output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2] + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + 
algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/reacher.py b/examples/gymnax/reacher.py new file mode 100644 index 0000000..08cf04c --- /dev/null +++ b/examples/gymnax/reacher.py @@ -0,0 +1,36 @@ +from config import * +from pipeline_jitable_env import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import GymNaxConfig, GymNaxEnv + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=500, + pop_size=10000 + ), + neat=NeatConfig( + inputs=8, + outputs=2, + ), + gene=NormalGeneConfig( + activation_default=Act.sigmoid, + activation_options=(Act.sigmoid,), + ), + problem=GymNaxConfig( + env_name='Reacher-misc', + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/pipeline_jitable_env.py b/pipeline_jitable_env.py new file mode 100644 index 0000000..8c365d1 --- /dev/null +++ b/pipeline_jitable_env.py @@ -0,0 +1,120 @@ +""" +pipeline for jitable env like func_fit, gymnax +""" + +from functools import partial +from typing import Type + +import jax +import time +import numpy as np + +from algorithm import NEAT, HyperNEAT +from config import Config +from core import State, Algorithm, Problem + + +class Pipeline: + + def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]): + + assert problem_type.jitable, "problem must be jitable" + + self.config = config + self.algorithm = algorithm + self.problem = problem_type(config.problem) + + if isinstance(algorithm, NEAT): + assert config.neat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" 
+ + elif isinstance(algorithm, HyperNEAT): + assert config.hyperneat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" + + else: + raise NotImplementedError + + self.act_func = self.algorithm.act + + for _ in range(len(self.problem.input_shape) - 1): + self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None)) + + self.best_genome = None + self.best_fitness = float('-inf') + self.generation_timestamp = None + + def setup(self): + key = jax.random.PRNGKey(self.config.basic.seed) + algorithm_key, evaluate_key = jax.random.split(key, 2) + state = State() + state = self.algorithm.setup(algorithm_key, state) + return state.update( + evaluate_key=evaluate_key + ) + + @partial(jax.jit, static_argnums=(0,)) + def step(self, state): + + key, sub_key = jax.random.split(state.evaluate_key) + keys = jax.random.split(key, self.config.basic.pop_size) + + pop = self.algorithm.ask(state) + + pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop) + + fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func, + pop_transformed) + + state = self.algorithm.tell(state, fitnesses) + + return state.update(evaluate_key=sub_key), fitnesses + + def auto_run(self, ini_state): + state = ini_state + for _ in range(self.config.basic.generation_limit): + + self.generation_timestamp = time.time() + + previous_pop = self.algorithm.ask(state) + + state, fitnesses = self.step(state) + + fitnesses = jax.device_get(fitnesses) + + self.analysis(state, previous_pop, fitnesses) + + if max(fitnesses) >= self.config.basic.fitness_target: + print("Fitness limit reached!") + return state, self.best_genome + + print("Generation limit reached!") + return state, self.best_genome + + def analysis(self, state, pop, fitnesses): + + max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) + + new_timestamp = time.time() + + cost_time = new_timestamp - 
self.generation_timestamp + + max_idx = np.argmax(fitnesses) + if fitnesses[max_idx] > self.best_fitness: + self.best_fitness = fitnesses[max_idx] + self.best_genome = pop[max_idx] + + member_count = jax.device_get(state.species_info.member_count) + species_sizes = [int(i) for i in member_count if i > 0] + + print(f"Generation: {state.generation}", + f"species: {len(species_sizes)}, {species_sizes}", + f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") + + def show(self, state, genome): + transformed = self.algorithm.transform(state, genome) + self.problem.show(state.evaluate_key, state, self.act_func, transformed) + + def pre_compile(self, state): + tic = time.time() + print("start compile") + self.step.lower(self, state).compile() + print(f"compile finished, cost time: {time.time() - tic}s") diff --git a/pipeline.py b/pipeline_time.py similarity index 94% rename from pipeline.py rename to pipeline_time.py index 0e4ca06..d54751a 100644 --- a/pipeline.py +++ b/pipeline_time.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Type import jax @@ -44,7 +43,6 @@ class Pipeline: evaluate_key=evaluate_key ) - @partial(jax.jit, static_argnums=(0,)) def step(self, state): key, sub_key = jax.random.split(state.evaluate_key) @@ -110,6 +108,4 @@ class Pipeline: tic = time.time() print("start compile") self.step.lower(self, state).compile() - # compiled_step = jax.jit(self.step, static_argnums=(0,)).lower(state).compile() - # self.__dict__['step'] = compiled_step print(f"compile finished, cost time: {time.time() - tic}s") diff --git a/problem/func_fit/func_fit.py b/problem/func_fit/func_fit.py index 3e904fd..43f3043 100644 --- a/problem/func_fit/func_fit.py +++ b/problem/func_fit/func_fit.py @@ -17,6 +17,8 @@ class FuncFitConfig(ProblemConfig): class FuncFit(Problem): + jitable = True + def __init__(self, config: FuncFitConfig = FuncFitConfig()): self.config = config super().__init__(config) diff --git 
a/problem/rl_env/gym_env.py b/problem/rl_env/gym_env.py new file mode 100644 index 0000000..080152a --- /dev/null +++ b/problem/rl_env/gym_env.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import Callable + +import gym + +from core import State +from .rl_unjit import RLEnv, RLEnvConfig + + +@dataclass(frozen=True) +class GymConfig(RLEnvConfig): + env_name: str = "CartPole-v1" + + def __post_init__(self): + assert self.env_name in gym.envs.registry.env_specs, f"Env {self.env_name} not registered" + + +class GymEnv(RLEnv): + + def __init__(self, config: GymConfig = GymConfig()): + super().__init__(config) + self.config = config + self.env, self.env_params = gym.make(config.env_name), None + + def env_step(self, randkey, env_state, action): + return self.env.step(randkey, env_state, action, self.env_params) + + def env_reset(self, randkey): + return self.env.reset(randkey, self.env_params) + + @property + def input_shape(self): + return self.env.observation_space(self.env_params).shape + + @property + def output_shape(self): + return self.env.action_space(self.env_params).shape + + def show(self, randkey, state: State, act_func: Callable, params): + raise NotImplementedError("Gym render must rely on gym 0.19.0(old version).") diff --git a/problem/rl_env/gymnax_env.py b/problem/rl_env/gymnax_env.py index f945e1d..872c63e 100644 --- a/problem/rl_env/gymnax_env.py +++ b/problem/rl_env/gymnax_env.py @@ -1,12 +1,10 @@ from dataclasses import dataclass from typing import Callable -import jax -import jax.numpy as jnp import gymnax from core import State -from .rl_env import RLEnv, RLEnvConfig +from .rl_jit import RLEnv, RLEnvConfig @dataclass(frozen=True) diff --git a/problem/rl_env/rl_env.py b/problem/rl_env/rl_jit.py similarity index 99% rename from problem/rl_env/rl_env.py rename to problem/rl_env/rl_jit.py index 2eaa9a0..0d12266 100644 --- a/problem/rl_env/rl_env.py +++ b/problem/rl_env/rl_jit.py @@ -16,6 +16,8 @@ class
RLEnv(Problem): + jitable = True + def __init__(self, config: RLEnvConfig = RLEnvConfig()): super().__init__(config) self.config = config diff --git a/problem/rl_env/rl_unjit.py b/problem/rl_env/rl_unjit.py new file mode 100644 index 0000000..0ab7c3f --- /dev/null +++ b/problem/rl_env/rl_unjit.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Callable + +import jax + +from config import ProblemConfig + +from core import Problem, State + + +@dataclass(frozen=True) +class RLEnvConfig(ProblemConfig): + output_transform: Callable = lambda x: x + + +class RLEnv(Problem): + + jitable = False + + def __init__(self, config: RLEnvConfig = RLEnvConfig()): + super().__init__(config) + self.config = config + + def evaluate(self, randkey, state: State, act_func: Callable, params): + rng_reset, rng_episode = jax.random.split(randkey) + init_obs, init_env_state = self.reset(rng_reset) + + def cond_func(carry): + _, _, _, done, _ = carry + return ~done + + def body_func(carry): + obs, env_state, rng, _, tr = carry # total reward + net_out = act_func(state, obs, params) + action = self.config.output_transform(net_out) + next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) + next_rng, _ = jax.random.split(rng) + return next_obs, next_env_state, next_rng, done, tr + reward + + _, _, _, _, total_reward = jax.lax.while_loop( + cond_func, + body_func, + (init_obs, init_env_state, rng_episode, False, 0.0) + ) + + return total_reward + + def step(self, randkey, env_state, action): + return self.env_step(randkey, env_state, action) + + def reset(self, randkey): + return self.env_reset(randkey) + + def env_step(self, randkey, env_state, action): + raise NotImplementedError + + def env_reset(self, randkey): + raise NotImplementedError + + @property + def input_shape(self): + raise NotImplementedError + + @property + def output_shape(self): + raise NotImplementedError + + def show(self, randkey, state: State, act_func: Callable, params): + 
raise NotImplementedError