add mp4 format to show; fix a bug in visualize

This commit is contained in:
Zenbook
2025-02-23 18:16:44 +08:00
parent d86a3196bd
commit c2566c3931
2 changed files with 72 additions and 7 deletions

View File

@@ -42,7 +42,7 @@ class BraxEnv(RLEnv):
**kwargs,
):
assert output_type in ["rgb_array", "gif"]
assert output_type in ["rgb_array", "gif", "mp4"]
import jax
import imageio
@@ -93,8 +93,14 @@ class BraxEnv(RLEnv):
return imgs
if save_path is None:
save_path = f"{self.env_name}.gif"
save_path = f"{self.env_name}.{output_type}"
imageio.mimsave(save_path, imgs, *args, **kwargs)
print("Gif saved to: ", save_path)
if output_type == "gif":
imageio.mimsave(save_path, imgs, *args, **kwargs)
elif output_type == "mp4":
fps = kwargs.get("fps", 30)
imageio.mimsave(save_path, imgs, fps=fps, codec="libx264", format="mp4")
print(f"{output_type} saved to: ", save_path)