add backend="jax" to sympy module

This commit is contained in:
wls2002
2024-06-13 05:55:33 +08:00
parent b3e442c688
commit 69d73aab73
12 changed files with 254 additions and 167 deletions

View File

@@ -1,3 +1,5 @@
import jax.numpy as jnp
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
from .tools import *
from .graph import *
@@ -29,6 +31,7 @@ name2sympy = {
"min": SympyMin,
"maxabs": SympyMaxabs,
"mean": SympyMean,
"clip": SympyClip,
}
@@ -45,7 +48,9 @@ def convert_to_sympy(func: Union[str, callable]):
)
FUNCS_MODULE = {}
SYMPY_FUNCS_MODULE_NP = {}
SYMPY_FUNCS_MODULE_JNP = {}
for cls in name2sympy.values():
if hasattr(cls, "numerical_eval"):
FUNCS_MODULE[cls.__name__] = cls.numerical_eval
SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)

View File

@@ -1,4 +1,3 @@
from typing import Union
import sympy as sp
import numpy as np
@@ -13,8 +12,8 @@ class SympyClip(sp.Function):
return None
@staticmethod
def numerical_eval(val, min_val, max_val):
return np.clip(val, min_val, max_val)
def numerical_eval(val, min_val, max_val, backend=np):
return backend.clip(val, min_val, max_val)
def _sympystr(self, printer):
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
@@ -32,9 +31,9 @@ class SympySigmoid(sp.Function):
return None
@staticmethod
def numerical_eval(z):
z = np.clip(5 * z, -10, 10)
return 1 / (1 + np.exp(-z))
def numerical_eval(z, backend=np):
z = backend.clip(5 * z, -10, 10)
return 1 / (1 + backend.exp(-z))
def _sympystr(self, printer):
return f"sigmoid({self.args[0]})"
@@ -49,8 +48,8 @@ class SympyTanh(sp.Function):
return sp.tanh(0.6 * z)
@staticmethod
def numerical_eval(z):
return np.tanh(0.6 * z)
def numerical_eval(z, backend=np):
return backend.tanh(0.6 * z)
class SympySin(sp.Function):
@@ -59,8 +58,8 @@ class SympySin(sp.Function):
return sp.sin(z)
@staticmethod
def numerical_eval(z):
return np.sin(z)
def numerical_eval(z, backend=np):
return backend.sin(z)
class SympyRelu(sp.Function):
@@ -71,8 +70,8 @@ class SympyRelu(sp.Function):
return None
@staticmethod
def numerical_eval(z):
return np.maximum(z, 0)
def numerical_eval(z, backend=np):
return backend.maximum(z, 0)
def _sympystr(self, printer):
return f"relu({self.args[0]})"
@@ -90,9 +89,9 @@ class SympyLelu(sp.Function):
return None
@staticmethod
def numerical_eval(z):
def numerical_eval(z, backend=np):
leaky = 0.005
return np.maximum(z, leaky * z)
return backend.maximum(z, leaky * z)
def _sympystr(self, printer):
return f"lelu({self.args[0]})"
@@ -107,7 +106,7 @@ class SympyIdentity(sp.Function):
return z
@staticmethod
def numerical_eval(z):
def numerical_eval(z, backend=np):
return z
@@ -117,8 +116,8 @@ class SympyClamped(sp.Function):
return SympyClip(z, -1, 1)
@staticmethod
def numerical_eval(z):
return np.clip(z, -1, 1)
def numerical_eval(z, backend=np):
return backend.clip(z, -1, 1)
class SympyInv(sp.Function):
@@ -130,8 +129,8 @@ class SympyInv(sp.Function):
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
def numerical_eval(z, backend=np):
z = backend.maximum(z, 1e-7)
return 1 / z
def _sympystr(self, printer):
@@ -150,9 +149,9 @@ class SympyLog(sp.Function):
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return np.log(z)
def numerical_eval(z, backend=np):
z = backend.maximum(z, 1e-7)
return backend.log(z)
def _sympystr(self, printer):
return f"log({self.args[0]})"
@@ -169,11 +168,6 @@ class SympyExp(sp.Function):
return sp.exp(z)
return None
@staticmethod
def numerical_eval(z):
z = np.clip(z, -10, 10)
return np.exp(z)
def _sympystr(self, printer):
return f"exp({self.args[0]})"
@@ -185,7 +179,3 @@ class SympyAbs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Abs(z)
@staticmethod
def numerical_eval(z):
return np.abs(z)

View File

@@ -1,3 +1,4 @@
import numpy as np
import sympy as sp
@@ -51,15 +52,6 @@ class SympyMedian(sp.Function):
return None
@staticmethod
def numerical_eval(args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
def _sympystr(self, printer):
return f"median({', '.join(map(str, self.args))})"