complete fully stateful!

use black to format all files!
This commit is contained in:
wls2002
2024-05-26 18:08:43 +08:00
parent cf69b916af
commit 18c3d44c79
41 changed files with 620 additions and 495 deletions

View File

@@ -25,7 +25,19 @@ class BraxEnv(RLEnv):
def output_shape(self):
return (self.env.action_size,)
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
def show(
self,
state,
randkey,
act_func,
params,
save_path=None,
height=512,
width=512,
duration=0.1,
*args,
**kwargs
):
import jax
import imageio
@@ -48,11 +60,13 @@ class BraxEnv(RLEnv):
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
reward += r
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in
tqdm(state_histories, desc="Rendering")]
imgs = [
image.render_array(sys=self.env.sys, state=s, width=width, height=height)
for s in tqdm(state_histories, desc="Rendering")
]
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
with imageio.get_writer(gif_name, mode="I", duration=duration) as writer:
for image in image_list:
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)
@@ -60,5 +74,3 @@ class BraxEnv(RLEnv):
create_gif(imgs, save_path, duration=0.1)
print("Gif saved to: ", save_path)
print("Total reward: ", reward)