add input_transform and update_input_transform;
change the args for genome.forward. Origin: (state, inputs, transformed) New: (state, transformed, inputs)
This commit is contained in:
@@ -51,7 +51,7 @@ class BraxEnv(RLEnv):
|
||||
|
||||
def step(key, env_state, obs):
|
||||
key, _ = jax.random.split(key)
|
||||
action = act_func(obs, params)
|
||||
action = act_func(params, obs)
|
||||
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
|
||||
return key, env_state, obs, r, done
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class RLEnv(BaseProblem):
|
||||
|
||||
def body_func(carry):
|
||||
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
|
||||
action = act_func(state, obs, params)
|
||||
action = act_func(state, params, obs)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(
|
||||
rng, env_state, action
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user