update problem and pipeline

This commit is contained in:
root
2024-07-11 19:34:12 +08:00
parent be6a67d7e2
commit cef27b56bb
14 changed files with 40 additions and 205 deletions

View File

@@ -1,4 +1,4 @@
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d
from .custom import CustomFuncFit
from .custom import CustomFuncFit
from .func_fit import FuncFit

View File

@@ -1,6 +1,4 @@
from typing import Callable, Union, List, Tuple, Sequence
import jax
from typing import Callable, Union, List, Tuple
from jax import vmap, Array, numpy as jnp
import numpy as np

View File

@@ -1,19 +1,18 @@
import jax
import jax.numpy as jnp
from ..base import BaseProblem
from tensorneat.common import State
from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self, error_method: str = "mse", return_data: bool = False):
def __init__(self, error_method: str = "mse"):
super().__init__()
assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method
self.return_data = return_data
def setup(self, state: State = State()):
return state
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
else:
raise NotImplementedError
if self.return_data:
return -loss, self.inputs
else:
return -loss
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data:
loss, _ = self.evaluate(state, randkey, act_func, params)
else:
loss = self.evaluate(state, randkey, act_func, params)
loss = -loss
fitness = self.evaluate(state, randkey, act_func, params)
loss = -fitness
msg = ""
for i in range(inputs.shape[0]):