From 2c6adf737758fe67fc332ec6e08be5d9c78d709f Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 19 Apr 2025 20:15:24 +0800 Subject: [PATCH] make rl envs lazy import --- src/tensorneat/problem/rl/brax.py | 3 +-- src/tensorneat/problem/rl/gymnax.py | 3 +-- src/tensorneat/problem/rl/mujoco_playground.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/tensorneat/problem/rl/brax.py b/src/tensorneat/problem/rl/brax.py index f7ac483..84edfba 100644 --- a/src/tensorneat/problem/rl/brax.py +++ b/src/tensorneat/problem/rl/brax.py @@ -1,6 +1,4 @@ import jax.numpy as jnp -from brax import envs - from .rl_jit import RLEnv, norm_obs @@ -8,6 +6,7 @@ class BraxEnv(RLEnv): def __init__( self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs ): + from brax import envs super().__init__(*args, **kwargs) self.env_name = env_name self.env = envs.create(env_name=env_name, backend=backend) diff --git a/src/tensorneat/problem/rl/gymnax.py b/src/tensorneat/problem/rl/gymnax.py index 4f17e25..9e78d81 100644 --- a/src/tensorneat/problem/rl/gymnax.py +++ b/src/tensorneat/problem/rl/gymnax.py @@ -1,10 +1,9 @@ -import gymnax - from .rl_jit import RLEnv class GymNaxEnv(RLEnv): def __init__(self, env_name, *args, **kwargs): + import gymnax super().__init__(*args, **kwargs) assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax." self.env, self.env_params = gymnax.make(env_name) diff --git a/src/tensorneat/problem/rl/mujoco_playground.py b/src/tensorneat/problem/rl/mujoco_playground.py index e0d0ec0..fecb15e 100644 --- a/src/tensorneat/problem/rl/mujoco_playground.py +++ b/src/tensorneat/problem/rl/mujoco_playground.py @@ -1,7 +1,5 @@ import jax.numpy as jnp from jax import Array -from mujoco_playground import registry - from .rl_jit import RLEnv, norm_obs @@ -9,6 +7,7 @@ class MujocoEnv(RLEnv): def __init__( self, env_name: str = "SwimmerSwimmer6", *args, **kwargs ): + from mujoco_playground import registry super().__init__(*args, **kwargs) self.env_name = env_name self.env = registry.load(env_name=env_name)