finish all refactoring

This commit is contained in:
wls2002
2024-02-21 15:41:08 +08:00
parent aac41a089d
commit 6970e6a6d5
44 changed files with 856 additions and 825 deletions

View File

@@ -1,24 +1,27 @@
import jax
import jax.numpy as jnp
from utils import State
from .. import BaseProblem
class FuncFit(BaseProblem):
class FuncFit(BaseProblem):
jitable = True
def __init__(self,
error_method: str = 'mse'
):
error_method: str = 'mse'
):
super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def setup(self, randkey, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params):
predict = act_func(state, self.inputs, params)
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
@@ -38,7 +41,7 @@ class FuncFit(BaseProblem):
return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = act_func(state, self.inputs, params)
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)
msg = ""