add jumanji env;
add repeat times for rl_env
This commit is contained in:
@@ -46,7 +46,7 @@ class FuncFit(BaseProblem):
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs, params
|
||||
state, params, self.inputs
|
||||
)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
if self.return_data:
|
||||
|
||||
Reference in New Issue
Block a user