From 7f042e07c249207a64f140146588419d7decd943 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 17 Oct 2023 20:20:03 +0800 Subject: [PATCH] add brax env --- examples/brax/ant.py | 37 ++++++++++++++++++++++++++++ examples/brax/half_cheetah.py | 41 +++++++++++++++++++++++++++++++ examples/brax/reacher.py | 38 +++++++++++++++++++++++++++++ examples/brax_env.py | 36 ++++++++++++++++++++++++++++ examples/general_xor.py | 5 ---- pipeline.py | 2 ++ problem/rl_env/__init__.py | 1 + problem/rl_env/brax_env.py | 45 +++++++++++++++++++++++++++++++++++ problem/rl_env/rl_jit.py | 2 +- 9 files changed, 201 insertions(+), 6 deletions(-) create mode 100644 examples/brax/ant.py create mode 100644 examples/brax/half_cheetah.py create mode 100644 examples/brax/reacher.py create mode 100644 examples/brax_env.py create mode 100644 problem/rl_env/brax_env.py diff --git a/examples/brax/ant.py b/examples/brax/ant.py new file mode 100644 index 0000000..f4034b1 --- /dev/null +++ b/examples/brax/ant.py @@ -0,0 +1,37 @@ +import jax.numpy as jnp + +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import BraxEnv, BraxConfig + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=10000, + pop_size=100 + ), + neat=NeatConfig( + inputs=27, + outputs=8, + ), + gene=NormalGeneConfig( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + problem=BraxConfig( + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, BraxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/examples/brax/half_cheetah.py b/examples/brax/half_cheetah.py new file mode 100644 index 0000000..eb2baf9 --- /dev/null +++ b/examples/brax/half_cheetah.py @@ -0,0 +1,41 @@ +import jax.numpy as jnp + +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import BraxEnv, BraxConfig + + +# ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d'] + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=10000, + pop_size=10000 + ), + neat=NeatConfig( + inputs=17, + outputs=6, + ), + gene=NormalGeneConfig( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + problem=BraxConfig( + env_name="halfcheetah" + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, BraxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/examples/brax/reacher.py b/examples/brax/reacher.py new file mode 100644 index 0000000..af31765 --- /dev/null +++ b/examples/brax/reacher.py @@ -0,0 +1,38 @@ +import jax.numpy as jnp + +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import BraxEnv, BraxConfig + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=10000, + pop_size=10000 + ), + neat=NeatConfig( + inputs=11, + outputs=2, + ), + gene=NormalGeneConfig( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + problem=BraxConfig( + env_name="reacher" + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, BraxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) diff --git a/examples/brax_env.py b/examples/brax_env.py new file mode 100644 index 0000000..61d94e7 --- /dev/null +++ b/examples/brax_env.py @@ -0,0 +1,36 @@ +import jax + +import brax +from brax import envs + + +def inference_func(key, *args): + return jax.random.normal(key, shape=(env.action_size,)) + + +env_name = "ant" +backend = "generalized" + +env = envs.create(env_name=env_name, backend=backend) + +jit_env_reset = jax.jit(env.reset) +jit_env_step = jax.jit(env.step) +jit_inference_fn = jax.jit(inference_func) + + +rollout = [] +rng = jax.random.PRNGKey(seed=1) +ori_state = jit_env_reset(rng=rng) +state = ori_state + +for _ in range(100): + rollout.append(state.pipeline_state) + act_rng, rng = jax.random.split(rng) + act = jit_inference_fn(act_rng, state.obs) + state = jit_env_step(state, act) + reward = state.reward + # print(reward) + +a = 1 + + diff --git a/examples/general_xor.py b/examples/general_xor.py index edcb994..a2d45ee 100644 --- a/examples/general_xor.py +++ b/examples/general_xor.py @@ -4,11 +4,6 @@ from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.func_fit import XOR, FuncFitConfig -def evaluate(): - pass - - - if __name__ == '__main__': config = Config( basic=BasicConfig( diff --git a/pipeline.py b/pipeline.py index 8c365d1..f6ebbea 100644 --- a/pipeline.py +++ b/pipeline.py @@ -24,6 +24,8 @@ class Pipeline: self.algorithm = algorithm self.problem = problem_type(config.problem) + print(self.problem.input_shape, self.problem.output_shape) + if isinstance(algorithm, NEAT): assert config.neat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" diff --git a/problem/rl_env/__init__.py b/problem/rl_env/__init__.py index 63e273a..acc4653 100644 --- a/problem/rl_env/__init__.py +++ b/problem/rl_env/__init__.py @@ -1 +1,2 @@ from .gymnax_env import GymNaxEnv, GymNaxConfig +from .brax_env import BraxEnv, BraxConfig diff --git a/problem/rl_env/brax_env.py b/problem/rl_env/brax_env.py new file mode 100644 index 0000000..710834b --- /dev/null +++ b/problem/rl_env/brax_env.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from typing import Callable + +import jax.numpy as jnp +from brax import envs +from core import State +from .rl_jit import RLEnv, RLEnvConfig + + +@dataclass(frozen=True) +class BraxConfig(RLEnvConfig): + env_name: str = "ant" + backend: str = "generalized" + + def __post_init__(self): + # TODO: Check if env_name is registered + # assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered" + pass + + +class BraxEnv(RLEnv): + def __init__(self, config: BraxConfig = BraxConfig()): + super().__init__(config) + self.config = config + self.env = envs.create(env_name=config.env_name, backend=config.backend) + + def env_step(self, randkey, env_state, action): + state = self.env.step(env_state, action) + return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info + + def env_reset(self, randkey): + init_state = self.env.reset(randkey) + return init_state.obs, init_state + + @property + def input_shape(self): + return (self.env.observation_size, ) + + @property + def output_shape(self): + return (self.env.action_size, ) + + def show(self, randkey, state: State, act_func: Callable, params): + # TODO + raise NotImplementedError("im busy! to de done!") diff --git a/problem/rl_env/rl_jit.py b/problem/rl_env/rl_jit.py index 0d12266..f8244d2 100644 --- a/problem/rl_env/rl_jit.py +++ b/problem/rl_env/rl_jit.py @@ -29,10 +29,10 @@ class RLEnv(Problem): def cond_func(carry): _, _, _, done, _ = carry return ~done - def body_func(carry): obs, env_state, rng, _, tr = carry # total reward net_out = act_func(state, obs, params) + action = self.config.output_transform(net_out) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_rng, _ = jax.random.split(rng)