add args record_episode in rl tasks, with related test "test_record_episode.ipynb";

add args return_data in func_fit tasks.
This commit is contained in:
wls2002
2024-05-30 17:05:56 +08:00
parent 20320105e6
commit cd92f411dc
8 changed files with 512 additions and 22 deletions

View File

@@ -8,11 +8,12 @@ from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self, error_method: str = "mse"):
def __init__(self, error_method: str = "mse", return_data: bool = False):
super().__init__()
assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method
self.return_data = return_data
def setup(self, state: State = State()):
return state
@@ -38,7 +39,10 @@ class FuncFit(BaseProblem):
else:
raise NotImplementedError
return -loss
if self.return_data:
return -loss, self.inputs
else:
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, 0, None))(