Add input_transform and update_input_transform.

Change the argument order of genome.forward.
Original: (state, inputs, transformed)
New: (state, transformed, inputs)
This commit is contained in:
wls2002
2024-06-03 10:53:15 +08:00
parent a07a3b1cb2
commit edfb0596e7
16 changed files with 185 additions and 221 deletions

View File

@@ -20,8 +20,8 @@ class FuncFit(BaseProblem):
def evaluate(self, state, randkey, act_func, params):
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
state, self.inputs, params
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
if self.error_method == "mse":
@@ -45,8 +45,8 @@ class FuncFit(BaseProblem):
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
state, self.inputs, params
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs, params
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data:

View File

@@ -51,7 +51,7 @@ class BraxEnv(RLEnv):
def step(key, env_state, obs):
key, _ = jax.random.split(key)
action = act_func(obs, params)
action = act_func(params, obs)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done

View File

@@ -36,7 +36,7 @@ class RLEnv(BaseProblem):
def body_func(carry):
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
action = act_func(state, obs, params)
action = act_func(state, params, obs)
next_obs, next_env_state, reward, done, _ = self.step(
rng, env_state, action
)