add input_transform and update_input_transform;

change the args for genome.forward.
Origin: (state, inputs, transformed)
New: (state, transformed, inputs)
This commit is contained in:
wls2002
2024-06-03 10:53:15 +08:00
parent a07a3b1cb2
commit edfb0596e7
16 changed files with 185 additions and 221 deletions

View File

@@ -51,7 +51,7 @@ class BraxEnv(RLEnv):
def step(key, env_state, obs):
key, _ = jax.random.split(key)
action = act_func(obs, params)
action = act_func(params, obs)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done