diff --git a/src/tensorneat/problem/rl/brax.py b/src/tensorneat/problem/rl/brax.py index f7ac483..84edfba 100644 --- a/src/tensorneat/problem/rl/brax.py +++ b/src/tensorneat/problem/rl/brax.py @@ -1,6 +1,4 @@ import jax.numpy as jnp -from brax import envs - from .rl_jit import RLEnv, norm_obs @@ -8,6 +6,7 @@ class BraxEnv(RLEnv): def __init__( self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs ): + from brax import envs super().__init__(*args, **kwargs) self.env_name = env_name self.env = envs.create(env_name=env_name, backend=backend) diff --git a/src/tensorneat/problem/rl/gymnax.py b/src/tensorneat/problem/rl/gymnax.py index 4f17e25..9e78d81 100644 --- a/src/tensorneat/problem/rl/gymnax.py +++ b/src/tensorneat/problem/rl/gymnax.py @@ -1,10 +1,9 @@ -import gymnax - from .rl_jit import RLEnv class GymNaxEnv(RLEnv): def __init__(self, env_name, *args, **kwargs): + import gymnax super().__init__(*args, **kwargs) assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax." self.env, self.env_params = gymnax.make(env_name) diff --git a/src/tensorneat/problem/rl/mujoco_playground.py b/src/tensorneat/problem/rl/mujoco_playground.py index e0d0ec0..fecb15e 100644 --- a/src/tensorneat/problem/rl/mujoco_playground.py +++ b/src/tensorneat/problem/rl/mujoco_playground.py @@ -1,7 +1,5 @@ import jax.numpy as jnp from jax import Array -from mujoco_playground import registry - from .rl_jit import RLEnv, norm_obs @@ -9,6 +7,7 @@ class MujocoEnv(RLEnv): def __init__( self, env_name: str = "SwimmerSwimmer6", *args, **kwargs ): + from mujoco_playground import registry super().__init__(*args, **kwargs) self.env_name = env_name self.env = registry.load(env_name=env_name)