complete show() in brax env
This commit is contained in:
@@ -44,7 +44,7 @@ class FuncFit(Problem):
|
||||
|
||||
return -loss
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params):
|
||||
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
|
||||
predict = act_func(state, self.inputs, params)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
loss = -self.evaluate(randkey, state, act_func, params)
|
||||
|
||||
Reference in New Issue
Block a user