update problem and pipeline
This commit is contained in:
@@ -1,19 +1,18 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..base import BaseProblem
|
||||
from tensorneat.common import State
|
||||
from .. import BaseProblem
|
||||
|
||||
|
||||
class FuncFit(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
def __init__(self, error_method: str = "mse", return_data: bool = False):
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
super().__init__()
|
||||
|
||||
assert error_method in {"mse", "rmse", "mae", "mape"}
|
||||
self.error_method = error_method
|
||||
self.return_data = return_data
|
||||
|
||||
def setup(self, state: State = State()):
|
||||
return state
|
||||
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.return_data:
|
||||
return -loss, self.inputs
|
||||
else:
|
||||
return -loss
|
||||
return -loss
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs
|
||||
)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
if self.return_data:
|
||||
loss, _ = self.evaluate(state, randkey, act_func, params)
|
||||
else:
|
||||
loss = self.evaluate(state, randkey, act_func, params)
|
||||
loss = -loss
|
||||
fitness = self.evaluate(state, randkey, act_func, params)
|
||||
|
||||
loss = -fitness
|
||||
|
||||
msg = ""
|
||||
for i in range(inputs.shape[0]):
|
||||
|
||||
Reference in New Issue
Block a user