add brax env

This commit is contained in:
wls2002
2023-10-17 20:20:03 +08:00
parent f217d87ac6
commit 7f042e07c2
9 changed files with 201 additions and 6 deletions

View File

@@ -29,10 +29,10 @@ class RLEnv(Problem):
def cond_func(carry):
_, _, _, done, _ = carry
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)