fix show error of cartpolev1

This commit is contained in:
Zenbook
2025-02-17 23:34:33 +08:00
parent d205d34ad2
commit d86a3196bd

View File

@@ -41,7 +41,7 @@ class BraxEnv(RLEnv):
*args,
**kwargs,
):
assert output_type in ["rgb_array", "gif"]
import jax
@@ -76,12 +76,17 @@ class BraxEnv(RLEnv):
reward += r
if done:
break
print("Total reward: ", reward)
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
)
try:
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
)
except ValueError:
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width
)
if output_type == "rgb_array":
imgs = np.array(imgs)