update save method in pipeline
This commit is contained in:
@@ -33,11 +33,16 @@ class StatefulBaseClass:
|
||||
|
||||
return state
|
||||
|
||||
def show_config(self):
|
||||
def show_config(self, registered_objects=None):
|
||||
if registered_objects is None: # root call
|
||||
registered_objects = []
|
||||
|
||||
config = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if isinstance(value, StatefulBaseClass):
|
||||
config[str(key)] = value.show_config()
|
||||
if isinstance(value, StatefulBaseClass) and value not in registered_objects:
|
||||
registered_objects.append(value)
|
||||
config[str(key)] = value.show_config(registered_objects)
|
||||
|
||||
else:
|
||||
config[str(key)] = str(value)
|
||||
return config
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
from brax import envs
|
||||
|
||||
from .rl_jit import RLEnv
|
||||
from .rl_jit import RLEnv, norm_obs
|
||||
|
||||
|
||||
class BraxEnv(RLEnv):
|
||||
@@ -29,21 +29,25 @@ class BraxEnv(RLEnv):
|
||||
return (self.env.action_size,)
|
||||
|
||||
def show(
|
||||
self,
|
||||
state,
|
||||
randkey,
|
||||
act_func,
|
||||
params,
|
||||
save_path=None,
|
||||
height=480,
|
||||
width=480,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self,
|
||||
state,
|
||||
randkey,
|
||||
act_func,
|
||||
params,
|
||||
save_path=None,
|
||||
height=480,
|
||||
width=480,
|
||||
output_type="rgb_array",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert output_type in ["rgb_array", "gif"]
|
||||
|
||||
import jax
|
||||
import imageio
|
||||
from brax.io import image
|
||||
import numpy as np
|
||||
|
||||
obs, env_state = self.reset(randkey)
|
||||
reward, done = 0.0, False
|
||||
@@ -52,13 +56,16 @@ class BraxEnv(RLEnv):
|
||||
def step(key, env_state, obs):
|
||||
key, _ = jax.random.split(key)
|
||||
|
||||
if self.obs_normalization:
|
||||
obs = norm_obs(state, obs)
|
||||
|
||||
if self.action_policy is not None:
|
||||
forward_func = lambda obs: act_func(state, params, obs)
|
||||
action = self.action_policy(key, forward_func, obs)
|
||||
else:
|
||||
action = act_func(state, params, obs)
|
||||
|
||||
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
|
||||
obs, env_state, r, done, info = self.step(randkey, env_state, action)
|
||||
return key, env_state, obs, r, done
|
||||
|
||||
jit_step = jax.jit(step)
|
||||
@@ -69,15 +76,20 @@ class BraxEnv(RLEnv):
|
||||
reward += r
|
||||
if done:
|
||||
break
|
||||
|
||||
print("Total reward: ", reward)
|
||||
|
||||
imgs = image.render_array(
|
||||
sys=self.env.sys, trajectory=state_histories, height=height, width=width
|
||||
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
|
||||
)
|
||||
|
||||
if output_type == "rgb_array":
|
||||
imgs = np.array(imgs)
|
||||
return imgs
|
||||
|
||||
if save_path is None:
|
||||
save_path = f"{self.env_name}.gif"
|
||||
|
||||
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
||||
|
||||
print("Gif saved to: ", save_path)
|
||||
print("Total reward: ", reward)
|
||||
|
||||
Reference in New Issue
Block a user