diff --git a/tensorneat/examples/brax/show_test.py b/tensorneat/examples/brax/show_test.py
new file mode 100644
index 0000000..c50920e
--- /dev/null
+++ b/tensorneat/examples/brax/show_test.py
@@ -0,0 +1,19 @@
+import jax
+from problem.rl_env import BraxEnv
+
+
+def random_policy(randkey, forward_func, obs):
+    return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
+
+
+if __name__ == "__main__":
+    problem = BraxEnv(env_name="walker2d", max_step=1000, action_policy=random_policy)
+    state = problem.setup()
+    randkey = jax.random.key(0)
+    problem.show(
+        state,
+        randkey,
+        act_func=lambda state, params, obs: obs,
+        params=None,
+        save_path="walker2d_random_policy",
+    )
diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl_env/brax_env.py
index f3adb15..5ff5157 100644
--- a/tensorneat/problem/rl_env/brax_env.py
+++ b/tensorneat/problem/rl_env/brax_env.py
@@ -9,6 +9,7 @@ class BraxEnv(RLEnv):
         self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs
     ):
         super().__init__(*args, **kwargs)
+        self.env_name = env_name
         self.env = envs.create(env_name=env_name, backend=backend)
 
     def env_step(self, randkey, env_state, action):
@@ -34,45 +35,49 @@ class BraxEnv(RLEnv):
         act_func,
         params,
         save_path=None,
-        height=512,
-        width=512,
-        duration=0.1,
+        height=480,
+        width=480,
         *args,
-        **kwargs
+        **kwargs,
     ):
         import jax
         import imageio
-        import numpy as np
         from brax.io import image
-        from tqdm import tqdm
 
         obs, env_state = self.reset(randkey)
         reward, done = 0.0, False
-        state_histories = []
+        state_histories = [env_state.pipeline_state]
 
         def step(key, env_state, obs):
             key, _ = jax.random.split(key)
-            action = act_func(params, obs)
+
+            if self.action_policy is not None:
+                forward_func = lambda obs: act_func(state, params, obs)
+                action = self.action_policy(key, forward_func, obs)
+            else:
+                action = act_func(state, params, obs)
+
             obs, env_state, r, done, _ = self.step(randkey, env_state, action)
             return key, env_state, obs, r, done
 
-        while not done:
+        jit_step = jax.jit(step)
+
+        for _ in range(self.max_step):
+            key, env_state, obs, r, done = jit_step(randkey, env_state, obs)
             state_histories.append(env_state.pipeline_state)
-            key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
             reward += r
+            if done:
+                break
 
-        imgs = [
-            image.render_array(sys=self.env.sys, state=s, width=width, height=height)
-            for s in tqdm(state_histories, desc="Rendering")
-        ]
+        imgs = image.render_array(
+            sys=self.env.sys, trajectory=state_histories, height=height, width=width
+        )
 
-        def create_gif(image_list, gif_name, duration):
-            with imageio.get_writer(gif_name, mode="I", duration=duration) as writer:
-                for image in image_list:
-                    formatted_image = np.array(image, dtype=np.uint8)
-                    writer.append_data(formatted_image)
+        if save_path is None:
+            save_path = f"{self.env_name}.gif"
+
+        imageio.mimsave(save_path, imgs, *args, **kwargs)
 
-        create_gif(imgs, save_path, duration=0.1)
         print("Gif saved to: ", save_path)
         print("Total reward: ", reward)
 