From f032564a43d37075aebd9f6b363cf2dbe7401927 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 15 Jul 2024 11:21:52 +0800 Subject: [PATCH] update save method in pipeline --- src/tensorneat/common/stateful_class.py | 11 +++++-- src/tensorneat/problem/rl/brax.py | 42 ++++++++++++++++--------- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/tensorneat/common/stateful_class.py b/src/tensorneat/common/stateful_class.py index 5a87f30..a7e2c2a 100644 --- a/src/tensorneat/common/stateful_class.py +++ b/src/tensorneat/common/stateful_class.py @@ -33,11 +33,16 @@ class StatefulBaseClass: return state - def show_config(self): + def show_config(self, registered_objects=None): + if registered_objects is None: # root call + registered_objects = [] + config = {} for key, value in self.__dict__.items(): - if isinstance(value, StatefulBaseClass): - config[str(key)] = value.show_config() + if isinstance(value, StatefulBaseClass) and value not in registered_objects: + registered_objects.append(value) + config[str(key)] = value.show_config(registered_objects) + else: config[str(key)] = str(value) return config diff --git a/src/tensorneat/problem/rl/brax.py b/src/tensorneat/problem/rl/brax.py index 5ff5157..c7b2942 100644 --- a/src/tensorneat/problem/rl/brax.py +++ b/src/tensorneat/problem/rl/brax.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from brax import envs -from .rl_jit import RLEnv +from .rl_jit import RLEnv, norm_obs class BraxEnv(RLEnv): @@ -29,21 +29,25 @@ class BraxEnv(RLEnv): return (self.env.action_size,) def show( - self, - state, - randkey, - act_func, - params, - save_path=None, - height=480, - width=480, - *args, - **kwargs, - ): + self, + state, + randkey, + act_func, + params, + save_path=None, + height=480, + width=480, + output_type="rgb_array", + *args, + **kwargs, + ): + + assert output_type in ["rgb_array", "gif"] import jax import imageio from brax.io import image + import numpy as np obs, env_state = self.reset(randkey) reward, done = 0.0, False @@ -52,13 +56,16 @@ class BraxEnv(RLEnv): def step(key, env_state, obs): key, _ = jax.random.split(key) + if self.obs_normalization: + obs = norm_obs(state, obs) + if self.action_policy is not None: forward_func = lambda obs: act_func(state, params, obs) action = self.action_policy(key, forward_func, obs) else: action = act_func(state, params, obs) - obs, env_state, r, done, _ = self.step(randkey, env_state, action) + obs, env_state, r, done, info = self.step(randkey, env_state, action) return key, env_state, obs, r, done jit_step = jax.jit(step) @@ -69,15 +76,20 @@ class BraxEnv(RLEnv): reward += r if done: break + + print("Total reward: ", reward) imgs = image.render_array( - sys=self.env.sys, trajectory=state_histories, height=height, width=width + sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track" ) + if output_type == "rgb_array": + imgs = np.array(imgs) + return imgs + if save_path is None: save_path = f"{self.env_name}.gif" imageio.mimsave(save_path, imgs, *args, **kwargs) print("Gif saved to: ", save_path) - print("Total reward: ", reward)