make rl envs lazy import

This commit is contained in:
wls2002
2025-04-19 20:15:24 +08:00
parent 3d57ca4cdd
commit 2c6adf7377
3 changed files with 3 additions and 6 deletions

View File

@@ -1,6 +1,4 @@
import jax.numpy as jnp import jax.numpy as jnp
from brax import envs
from .rl_jit import RLEnv, norm_obs from .rl_jit import RLEnv, norm_obs
@@ -8,6 +6,7 @@ class BraxEnv(RLEnv):
def __init__( def __init__(
self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs
): ):
from brax import envs
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.env_name = env_name self.env_name = env_name
self.env = envs.create(env_name=env_name, backend=backend) self.env = envs.create(env_name=env_name, backend=backend)

View File

@@ -1,10 +1,9 @@
import gymnax
from .rl_jit import RLEnv from .rl_jit import RLEnv
class GymNaxEnv(RLEnv): class GymNaxEnv(RLEnv):
def __init__(self, env_name, *args, **kwargs): def __init__(self, env_name, *args, **kwargs):
import gymnax
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax." assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax."
self.env, self.env_params = gymnax.make(env_name) self.env, self.env_params = gymnax.make(env_name)

View File

@@ -1,7 +1,5 @@
import jax.numpy as jnp import jax.numpy as jnp
from jax import Array from jax import Array
from mujoco_playground import registry
from .rl_jit import RLEnv, norm_obs from .rl_jit import RLEnv, norm_obs
@@ -9,6 +7,7 @@ class MujocoEnv(RLEnv):
def __init__( def __init__(
self, env_name: str = "SwimmerSwimmer6", *args, **kwargs self, env_name: str = "SwimmerSwimmer6", *args, **kwargs
): ):
from mujoco_playground import registry
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.env_name = env_name self.env_name = env_name
self.env = registry.load(env_name=env_name) self.env = registry.load(env_name=env_name)