update functions
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..base import BaseProblem
|
||||
@@ -19,7 +20,7 @@ class FuncFit(BaseProblem):
|
||||
|
||||
def evaluate(self, state, randkey, act_func, params):
|
||||
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
predict = vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs
|
||||
)
|
||||
|
||||
@@ -41,7 +42,7 @@ class FuncFit(BaseProblem):
|
||||
return -loss
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
predict = vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs
|
||||
)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
|
||||
Reference in New Issue
Block a user