diff --git a/src/tensorneat/problem/rl/brax.py b/src/tensorneat/problem/rl/brax.py index c7b2942..cf6e84d 100644 --- a/src/tensorneat/problem/rl/brax.py +++ b/src/tensorneat/problem/rl/brax.py @@ -41,7 +41,7 @@ class BraxEnv(RLEnv): *args, **kwargs, ): - + assert output_type in ["rgb_array", "gif"] import jax @@ -76,12 +76,17 @@ class BraxEnv(RLEnv): reward += r if done: break - + print("Total reward: ", reward) - imgs = image.render_array( - sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track" - ) + try: + imgs = image.render_array( + sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track" + ) + except ValueError: + imgs = image.render_array( + sys=self.env.sys, trajectory=state_histories, height=height, width=width + ) if output_type == "rgb_array": imgs = np.array(imgs)