update problem and pipeline

This commit is contained in:
root
2024-07-11 19:34:12 +08:00
parent be6a67d7e2
commit cef27b56bb
14 changed files with 40 additions and 205 deletions

View File

@@ -1,19 +1,18 @@
import jax
import jax.numpy as jnp
from ..base import BaseProblem
from tensorneat.common import State
from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self, error_method: str = "mse", return_data: bool = False):
def __init__(self, error_method: str = "mse"):
super().__init__()
assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method
self.return_data = return_data
def setup(self, state: State = State()):
return state
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
else:
raise NotImplementedError
if self.return_data:
return -loss, self.inputs
else:
return -loss
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data:
loss, _ = self.evaluate(state, randkey, act_func, params)
else:
loss = self.evaluate(state, randkey, act_func, params)
loss = -loss
fitness = self.evaluate(state, randkey, act_func, params)
loss = -fitness
msg = ""
for i in range(inputs.shape[0]):