diff --git a/examples/mujoco_playground/swimmer6.py b/examples/mujoco_playground/swimmer6.py new file mode 100644 index 0000000..4084838 --- /dev/null +++ b/examples/mujoco_playground/swimmer6.py @@ -0,0 +1,61 @@ +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation + +from tensorneat.problem.rl import MujocoEnv +from tensorneat.common import ACT, AGG +import jax + + +def random_sample_policy(randkey, obs): + return jax.random.uniform(randkey, (8,), minval=-1.0, maxval=1.0) + + +if __name__ == "__main__": + pipeline = Pipeline( + algorithm=NEAT( + pop_size=3000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=0.8, + genome=DefaultGenome( + max_nodes=100, + max_conns=1500, + num_inputs=25, + num_outputs=5, + init_hidden_layers=(30,), + mutation=DefaultMutation( + node_delete=0.0, + ), + node_gene=BiasNode( + bias_init_std=0.1, + bias_mutate_power=0.05, + bias_mutate_rate=0.01, + bias_replace_rate=0.0, + activation_options=ACT.tanh, + aggregation_options=AGG.sum, + ), + conn_gene=DefaultConn( + weight_init_mean=0.0, + weight_init_std=0.1, + weight_mutate_power=0.05, + weight_replace_rate=0.0, + weight_mutate_rate=0.001, + ), + output_transform=ACT.tanh, + ), + ), + problem=MujocoEnv( + env_name="SwimmerSwimmer6", + max_step=1000, + ), + seed=42, + generation_limit=100, + fitness_target=8000, + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) diff --git a/src/tensorneat/problem/rl/__init__.py b/src/tensorneat/problem/rl/__init__.py index 60f754e..7100a3f 100644 --- a/src/tensorneat/problem/rl/__init__.py +++ b/src/tensorneat/problem/rl/__init__.py @@ -1,3 +1,4 @@ from .gymnax import GymNaxEnv from .brax import BraxEnv -from .rl_jit import RLEnv \ No newline at end of file +from .rl_jit import RLEnv +from .mujoco_playground import MujocoEnv \ No newline at end of file diff --git a/src/tensorneat/problem/rl/mujoco_playground.py b/src/tensorneat/problem/rl/mujoco_playground.py new file mode 100644 index 0000000..e0d0ec0 --- /dev/null +++ b/src/tensorneat/problem/rl/mujoco_playground.py @@ -0,0 +1,119 @@ +import jax.numpy as jnp +from jax import Array +from mujoco_playground import registry + +from .rl_jit import RLEnv, norm_obs + + +class MujocoEnv(RLEnv): + def __init__( + self, env_name: str = "SwimmerSwimmer6", *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.env_name = env_name + self.env = registry.load(env_name=env_name) + + def env_step(self, randkey, env_state, action): + state = self.env.step(env_state, action) + obs = state.obs + if not isinstance(obs, Array): + if "state" in obs: + obs = obs["state"] + else: + raise ImportError( + f"This Pytree observation space is not supported yet: {obs}" + ) + return obs, state, state.reward, state.done.astype(jnp.bool_), state.info + + def env_reset(self, randkey): + init_state = self.env.reset(randkey) + obs = init_state.obs + if not isinstance(obs, Array): + if "state" in obs: + obs = obs["state"] + else: + raise ImportError( + f"This Pytree observation space is not supported yet: {obs}" + ) + return obs, init_state + + @property + def input_shape(self): + return (self.env.observation_size,) + + @property + def output_shape(self): + return (self.env.action_size,) + + def show( + self, + state, + randkey, + act_func, + params, + save_path=None, + height=480, + width=480, + output_type="rgb_array", + *args, + **kwargs, + ): + + assert output_type in ["gif", "mp4"] + + import jax + import imageio + from brax.io import image + import numpy as np + + obs, env_state = self.reset(randkey) + reward, done = 0.0, False + state_histories = [env_state.pipeline_state] + + def step(key, env_state, obs): + key, _ = jax.random.split(key) + + if self.obs_normalization: + obs = norm_obs(state, obs) + + if self.action_policy is not None: + forward_func = lambda obs: act_func(state, params, obs) + action = self.action_policy(key, forward_func, obs) + else: + action = act_func(state, params, obs) + + obs, env_state, r, done, info = self.step(randkey, env_state, action) + return key, env_state, obs, r, done + + jit_step = jax.jit(step) + + for _ in range(self.max_step): + key, env_state, obs, r, done = jit_step(randkey, env_state, obs) + state_histories.append(env_state.pipeline_state) + reward += r + if done: + break + + print("Total reward: ", reward) + + try: + imgs = image.render_array( + sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track" + ) + except ValueError: + imgs = image.render_array( + sys=self.env.sys, trajectory=state_histories, height=height, width=width + ) + + if save_path is None: + save_path = f"{self.env_name}.{output_type}" + + imageio.mimsave(save_path, imgs, *args, **kwargs) + + if output_type == "gif": + imageio.mimsave(save_path, imgs, *args, **kwargs) + elif output_type == "mp4": + fps = kwargs.get("fps", 30) + imageio.mimsave(save_path, imgs, fps=fps, codec="libx264", format="mp4") + + print(f"{output_type} saved to: ", save_path)