update save method in pipeline

This commit is contained in:
root
2024-07-15 11:21:52 +08:00
parent aa8581cd11
commit f032564a43
2 changed files with 35 additions and 18 deletions

View File

@@ -33,11 +33,16 @@ class StatefulBaseClass:
return state
def show_config(self):
def show_config(self, registered_objects=None):
if registered_objects is None: # root call
registered_objects = []
config = {}
for key, value in self.__dict__.items():
if isinstance(value, StatefulBaseClass):
config[str(key)] = value.show_config()
if isinstance(value, StatefulBaseClass) and value not in registered_objects:
registered_objects.append(value)
config[str(key)] = value.show_config(registered_objects)
else:
config[str(key)] = str(value)
return config

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from brax import envs
from .rl_jit import RLEnv
from .rl_jit import RLEnv, norm_obs
class BraxEnv(RLEnv):
@@ -37,13 +37,17 @@ class BraxEnv(RLEnv):
save_path=None,
height=480,
width=480,
output_type="rgb_array",
*args,
**kwargs,
):
assert output_type in ["rgb_array", "gif"]
import jax
import imageio
from brax.io import image
import numpy as np
obs, env_state = self.reset(randkey)
reward, done = 0.0, False
@@ -52,13 +56,16 @@ class BraxEnv(RLEnv):
def step(key, env_state, obs):
key, _ = jax.random.split(key)
if self.obs_normalization:
obs = norm_obs(state, obs)
if self.action_policy is not None:
forward_func = lambda obs: act_func(state, params, obs)
action = self.action_policy(key, forward_func, obs)
else:
action = act_func(state, params, obs)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
obs, env_state, r, done, info = self.step(randkey, env_state, action)
return key, env_state, obs, r, done
jit_step = jax.jit(step)
@@ -70,14 +77,19 @@ class BraxEnv(RLEnv):
if done:
break
print("Total reward: ", reward)
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
)
if output_type == "rgb_array":
imgs = np.array(imgs)
return imgs
if save_path is None:
save_path = f"{self.env_name}.gif"
imageio.mimsave(save_path, imgs, *args, **kwargs)
print("Gif saved to: ", save_path)
print("Total reward: ", reward)