update save method in pipeline
This commit is contained in:
@@ -33,11 +33,16 @@ class StatefulBaseClass:
|
|||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def show_config(self):
|
def show_config(self, registered_objects=None):
|
||||||
|
if registered_objects is None: # root call
|
||||||
|
registered_objects = []
|
||||||
|
|
||||||
config = {}
|
config = {}
|
||||||
for key, value in self.__dict__.items():
|
for key, value in self.__dict__.items():
|
||||||
if isinstance(value, StatefulBaseClass):
|
if isinstance(value, StatefulBaseClass) and value not in registered_objects:
|
||||||
config[str(key)] = value.show_config()
|
registered_objects.append(value)
|
||||||
|
config[str(key)] = value.show_config(registered_objects)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
config[str(key)] = str(value)
|
config[str(key)] = str(value)
|
||||||
return config
|
return config
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from brax import envs
|
from brax import envs
|
||||||
|
|
||||||
from .rl_jit import RLEnv
|
from .rl_jit import RLEnv, norm_obs
|
||||||
|
|
||||||
|
|
||||||
class BraxEnv(RLEnv):
|
class BraxEnv(RLEnv):
|
||||||
@@ -29,21 +29,25 @@ class BraxEnv(RLEnv):
|
|||||||
return (self.env.action_size,)
|
return (self.env.action_size,)
|
||||||
|
|
||||||
def show(
|
def show(
|
||||||
self,
|
self,
|
||||||
state,
|
state,
|
||||||
randkey,
|
randkey,
|
||||||
act_func,
|
act_func,
|
||||||
params,
|
params,
|
||||||
save_path=None,
|
save_path=None,
|
||||||
height=480,
|
height=480,
|
||||||
width=480,
|
width=480,
|
||||||
*args,
|
output_type="rgb_array",
|
||||||
**kwargs,
|
*args,
|
||||||
):
|
**kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
|
assert output_type in ["rgb_array", "gif"]
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import imageio
|
import imageio
|
||||||
from brax.io import image
|
from brax.io import image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
obs, env_state = self.reset(randkey)
|
obs, env_state = self.reset(randkey)
|
||||||
reward, done = 0.0, False
|
reward, done = 0.0, False
|
||||||
@@ -52,13 +56,16 @@ class BraxEnv(RLEnv):
|
|||||||
def step(key, env_state, obs):
|
def step(key, env_state, obs):
|
||||||
key, _ = jax.random.split(key)
|
key, _ = jax.random.split(key)
|
||||||
|
|
||||||
|
if self.obs_normalization:
|
||||||
|
obs = norm_obs(state, obs)
|
||||||
|
|
||||||
if self.action_policy is not None:
|
if self.action_policy is not None:
|
||||||
forward_func = lambda obs: act_func(state, params, obs)
|
forward_func = lambda obs: act_func(state, params, obs)
|
||||||
action = self.action_policy(key, forward_func, obs)
|
action = self.action_policy(key, forward_func, obs)
|
||||||
else:
|
else:
|
||||||
action = act_func(state, params, obs)
|
action = act_func(state, params, obs)
|
||||||
|
|
||||||
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
|
obs, env_state, r, done, info = self.step(randkey, env_state, action)
|
||||||
return key, env_state, obs, r, done
|
return key, env_state, obs, r, done
|
||||||
|
|
||||||
jit_step = jax.jit(step)
|
jit_step = jax.jit(step)
|
||||||
@@ -69,15 +76,20 @@ class BraxEnv(RLEnv):
|
|||||||
reward += r
|
reward += r
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
print("Total reward: ", reward)
|
||||||
|
|
||||||
imgs = image.render_array(
|
imgs = image.render_array(
|
||||||
sys=self.env.sys, trajectory=state_histories, height=height, width=width
|
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if output_type == "rgb_array":
|
||||||
|
imgs = np.array(imgs)
|
||||||
|
return imgs
|
||||||
|
|
||||||
if save_path is None:
|
if save_path is None:
|
||||||
save_path = f"{self.env_name}.gif"
|
save_path = f"{self.env_name}.gif"
|
||||||
|
|
||||||
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
||||||
|
|
||||||
print("Gif saved to: ", save_path)
|
print("Gif saved to: ", save_path)
|
||||||
print("Total reward: ", reward)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user