add input_transform and update_input_transform;
change the args for genome.forward. Origin: (state, inputs, transformed) New: (state, transformed, inputs)
This commit is contained in:
@@ -20,8 +20,8 @@ class FuncFit(BaseProblem):
|
||||
|
||||
def evaluate(self, state, randkey, act_func, params):
|
||||
|
||||
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
|
||||
state, self.inputs, params
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs
|
||||
)
|
||||
|
||||
if self.error_method == "mse":
|
||||
@@ -45,8 +45,8 @@ class FuncFit(BaseProblem):
|
||||
return -loss
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
|
||||
state, self.inputs, params
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs, params
|
||||
)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
if self.return_data:
|
||||
|
||||
Reference in New Issue
Block a user