complete show() in brax env

This commit is contained in:
wls2002
2023-10-22 21:01:06 +08:00
parent 7f042e07c2
commit 15dadebd7e
11 changed files with 152 additions and 22 deletions

View File

@@ -32,7 +32,6 @@ class RLEnv(Problem):
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)
@@ -68,5 +67,5 @@ class RLEnv(Problem):
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params):
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
raise NotImplementedError