add backend="jax" to sympy module

This commit is contained in:
wls2002
2024-06-13 05:55:33 +08:00
parent b3e442c688
commit 69d73aab73
12 changed files with 254 additions and 167 deletions

View File

@@ -1,3 +1,5 @@
import jax.numpy as jnp
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
from .tools import *
from .graph import *
@@ -29,6 +31,7 @@ name2sympy = {
"min": SympyMin,
"maxabs": SympyMaxabs,
"mean": SympyMean,
"clip": SympyClip,
}
@@ -45,7 +48,9 @@ def convert_to_sympy(func: Union[str, callable]):
)
FUNCS_MODULE = {}
SYMPY_FUNCS_MODULE_NP = {}
SYMPY_FUNCS_MODULE_JNP = {}
for cls in name2sympy.values():
if hasattr(cls, "numerical_eval"):
FUNCS_MODULE[cls.__name__] = cls.numerical_eval
SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)