diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl_env/brax_env.py index 9c34501..4e0b505 100644 --- a/tensorneat/problem/rl_env/brax_env.py +++ b/tensorneat/problem/rl_env/brax_env.py @@ -39,7 +39,7 @@ class BraxEnv(RLEnv): def step(key, env_state, obs): key, _ = jax.random.split(key) - action = act_func(state, obs, params) + action = act_func(obs, params) obs, env_state, r, done, _ = self.step(randkey, env_state, action) return key, env_state, obs, r, done