fix show error of cartpolev1
This commit is contained in:
@@ -41,7 +41,7 @@ class BraxEnv(RLEnv):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
assert output_type in ["rgb_array", "gif"]
|
assert output_type in ["rgb_array", "gif"]
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
@@ -76,12 +76,17 @@ class BraxEnv(RLEnv):
|
|||||||
reward += r
|
reward += r
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
|
|
||||||
print("Total reward: ", reward)
|
print("Total reward: ", reward)
|
||||||
|
|
||||||
imgs = image.render_array(
|
try:
|
||||||
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
|
imgs = image.render_array(
|
||||||
)
|
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
imgs = image.render_array(
|
||||||
|
sys=self.env.sys, trajectory=state_histories, height=height, width=width
|
||||||
|
)
|
||||||
|
|
||||||
if output_type == "rgb_array":
|
if output_type == "rgb_array":
|
||||||
imgs = np.array(imgs)
|
imgs = np.array(imgs)
|
||||||
|
|||||||
Reference in New Issue
Block a user