complete show() in brax env

This commit is contained in:
wls2002
2023-10-22 21:01:06 +08:00
parent 7f042e07c2
commit 15dadebd7e
11 changed files with 152 additions and 22 deletions

View File

@@ -44,7 +44,7 @@ class FuncFit(Problem):
return -loss
def show(self, randkey, state: State, act_func: Callable, params):
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
predict = act_func(state, self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)