Add Jumanji env;

Add repeat times for rl_env
This commit is contained in:
wls2002
2024-06-05 14:24:17 +08:00
parent edfb0596e7
commit 10ec1c2df9
10 changed files with 1615 additions and 7 deletions

View File

@@ -46,7 +46,7 @@ class FuncFit(BaseProblem):
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
- state, params, self.inputs, params
+ state, params, self.inputs
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data: