add backend="jax" to sympy module
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
@@ -29,6 +31,7 @@ name2sympy = {
|
||||
"min": SympyMin,
|
||||
"maxabs": SympyMaxabs,
|
||||
"mean": SympyMean,
|
||||
"clip": SympyClip,
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +48,9 @@ def convert_to_sympy(func: Union[str, callable]):
|
||||
)
|
||||
|
||||
|
||||
FUNCS_MODULE = {}
|
||||
SYMPY_FUNCS_MODULE_NP = {}
|
||||
SYMPY_FUNCS_MODULE_JNP = {}
|
||||
for cls in name2sympy.values():
|
||||
if hasattr(cls, "numerical_eval"):
|
||||
FUNCS_MODULE[cls.__name__] = cls.numerical_eval
|
||||
SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
|
||||
SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)
|
||||
|
||||
Reference in New Issue
Block a user