add args record_episode in rl tasks, with related test "test_record_episode.ipynb";
add args return_data in func_fit tasks.
This commit is contained in:
@@ -8,11 +8,12 @@ from .. import BaseProblem
|
||||
class FuncFit(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
def __init__(self, error_method: str = "mse", return_data: bool = False):
|
||||
super().__init__()
|
||||
|
||||
assert error_method in {"mse", "rmse", "mae", "mape"}
|
||||
self.error_method = error_method
|
||||
self.return_data = return_data
|
||||
|
||||
def setup(self, state: State = State()):
|
||||
return state
|
||||
@@ -38,7 +39,10 @@ class FuncFit(BaseProblem):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return -loss
|
||||
if self.return_data:
|
||||
return -loss, self.inputs
|
||||
else:
|
||||
return -loss
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
|
||||
|
||||
@@ -4,8 +4,6 @@ from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR(FuncFit):
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
super().__init__(error_method)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
|
||||
@@ -4,9 +4,6 @@ from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR3d(FuncFit):
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
super().__init__(error_method)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return np.array(
|
||||
|
||||
Reference in New Issue
Block a user