update save method in pipeline

This commit is contained in:
root
2024-07-15 11:21:52 +08:00
parent aa8581cd11
commit f032564a43
2 changed files with 35 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from brax import envs
from .rl_jit import RLEnv
from .rl_jit import RLEnv, norm_obs
class BraxEnv(RLEnv):
@@ -29,21 +29,25 @@ class BraxEnv(RLEnv):
return (self.env.action_size,)
def show(
self,
state,
randkey,
act_func,
params,
save_path=None,
height=480,
width=480,
*args,
**kwargs,
):
self,
state,
randkey,
act_func,
params,
save_path=None,
height=480,
width=480,
output_type="rgb_array",
*args,
**kwargs,
):
assert output_type in ["rgb_array", "gif"]
import jax
import imageio
from brax.io import image
import numpy as np
obs, env_state = self.reset(randkey)
reward, done = 0.0, False
@@ -52,13 +56,16 @@ class BraxEnv(RLEnv):
def step(key, env_state, obs):
key, _ = jax.random.split(key)
if self.obs_normalization:
obs = norm_obs(state, obs)
if self.action_policy is not None:
forward_func = lambda obs: act_func(state, params, obs)
action = self.action_policy(key, forward_func, obs)
else:
action = act_func(state, params, obs)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
obs, env_state, r, done, info = self.step(randkey, env_state, action)
return key, env_state, obs, r, done
jit_step = jax.jit(step)
@@ -69,15 +76,20 @@ class BraxEnv(RLEnv):
reward += r
if done:
break
print("Total reward: ", reward)
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
)
if output_type == "rgb_array":
imgs = np.array(imgs)
return imgs
if save_path is None:
save_path = f"{self.env_name}.gif"
imageio.mimsave(save_path, imgs, *args, **kwargs)
print("Gif saved to: ", save_path)
print("Total reward: ", reward)