update functions

This commit is contained in:
root
2024-07-12 02:14:48 +08:00
parent 45b4155541
commit 3194678a15
15 changed files with 323 additions and 378 deletions

View File

@@ -1,4 +1,5 @@
import jax
from jax import vmap, numpy as jnp
import jax.numpy as jnp
from ..base import BaseProblem
@@ -19,7 +20,7 @@ class FuncFit(BaseProblem):
def evaluate(self, state, randkey, act_func, params):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
predict = vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
@@ -41,7 +42,7 @@ class FuncFit(BaseProblem):
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
predict = vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])