add brax env

This commit is contained in:
wls2002
2023-10-17 20:20:03 +08:00
parent f217d87ac6
commit 7f042e07c2
9 changed files with 201 additions and 6 deletions

View File

@@ -1 +1,2 @@
from .gymnax_env import GymNaxEnv, GymNaxConfig
from .brax_env import BraxEnv, BraxConfig

View File

@@ -0,0 +1,45 @@
from dataclasses import dataclass
from typing import Callable
import jax.numpy as jnp
from brax import envs
from core import State
from .rl_jit import RLEnv, RLEnvConfig
@dataclass(frozen=True)
class BraxConfig(RLEnvConfig):
env_name: str = "ant"
backend: str = "generalized"
def __post_init__(self):
# TODO: Check if env_name is registered
# assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered"
pass
class BraxEnv(RLEnv):
def __init__(self, config: BraxConfig = BraxConfig()):
super().__init__(config)
self.config = config
self.env = envs.create(env_name=config.env_name, backend=config.backend)
def env_step(self, randkey, env_state, action):
state = self.env.step(env_state, action)
return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info
def env_reset(self, randkey):
init_state = self.env.reset(randkey)
return init_state.obs, init_state
@property
def input_shape(self):
return (self.env.observation_size, )
@property
def output_shape(self):
return (self.env.action_size, )
def show(self, randkey, state: State, act_func: Callable, params):
# TODO
raise NotImplementedError("im busy! to de done!")

View File

@@ -29,10 +29,10 @@ class RLEnv(Problem):
def cond_func(carry):
_, _, _, done, _ = carry
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)