add mp4 format to show; fix a bug in visualize
This commit is contained in:
@@ -42,7 +42,7 @@ class BraxEnv(RLEnv):
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert output_type in ["rgb_array", "gif"]
|
||||
assert output_type in ["rgb_array", "gif", "mp4"]
|
||||
|
||||
import jax
|
||||
import imageio
|
||||
@@ -93,8 +93,14 @@ class BraxEnv(RLEnv):
|
||||
return imgs
|
||||
|
||||
if save_path is None:
|
||||
save_path = f"{self.env_name}.gif"
|
||||
save_path = f"{self.env_name}.{output_type}"
|
||||
|
||||
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
||||
|
||||
print("Gif saved to: ", save_path)
|
||||
if output_type == "gif":
|
||||
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
||||
elif output_type == "mp4":
|
||||
fps = kwargs.get("fps", 30)
|
||||
imageio.mimsave(save_path, imgs, fps=fps, codec="libx264", format="mp4")
|
||||
|
||||
print(f"{output_type} saved to: ", save_path)
|
||||
|
||||
Reference in New Issue
Block a user