add backend="jax" to sympy module
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
@@ -29,6 +31,7 @@ name2sympy = {
|
||||
"min": SympyMin,
|
||||
"maxabs": SympyMaxabs,
|
||||
"mean": SympyMean,
|
||||
"clip": SympyClip,
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +48,9 @@ def convert_to_sympy(func: Union[str, callable]):
|
||||
)
|
||||
|
||||
|
||||
FUNCS_MODULE = {}
|
||||
SYMPY_FUNCS_MODULE_NP = {}
|
||||
SYMPY_FUNCS_MODULE_JNP = {}
|
||||
for cls in name2sympy.values():
|
||||
if hasattr(cls, "numerical_eval"):
|
||||
FUNCS_MODULE[cls.__name__] = cls.numerical_eval
|
||||
SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
|
||||
SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import Union
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
@@ -13,8 +12,8 @@ class SympyClip(sp.Function):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(val, min_val, max_val):
|
||||
return np.clip(val, min_val, max_val)
|
||||
def numerical_eval(val, min_val, max_val, backend=np):
|
||||
return backend.clip(val, min_val, max_val)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
|
||||
@@ -32,9 +31,9 @@ class SympySigmoid(sp.Function):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.clip(5 * z, -10, 10)
|
||||
return 1 / (1 + np.exp(-z))
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.clip(5 * z, -10, 10)
|
||||
return 1 / (1 + backend.exp(-z))
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"sigmoid({self.args[0]})"
|
||||
@@ -49,8 +48,8 @@ class SympyTanh(sp.Function):
|
||||
return sp.tanh(0.6 * z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.tanh(0.6 * z)
|
||||
def numerical_eval(z, backend=np):
|
||||
return backend.tanh(0.6 * z)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
|
||||
@@ -59,8 +58,8 @@ class SympySin(sp.Function):
|
||||
return sp.sin(z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.sin(z)
|
||||
def numerical_eval(z, backend=np):
|
||||
return backend.sin(z)
|
||||
|
||||
|
||||
class SympyRelu(sp.Function):
|
||||
@@ -71,8 +70,8 @@ class SympyRelu(sp.Function):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.maximum(z, 0)
|
||||
def numerical_eval(z, backend=np):
|
||||
return backend.maximum(z, 0)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"relu({self.args[0]})"
|
||||
@@ -90,9 +89,9 @@ class SympyLelu(sp.Function):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
def numerical_eval(z, backend=np):
|
||||
leaky = 0.005
|
||||
return np.maximum(z, leaky * z)
|
||||
return backend.maximum(z, leaky * z)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"lelu({self.args[0]})"
|
||||
@@ -107,7 +106,7 @@ class SympyIdentity(sp.Function):
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
def numerical_eval(z, backend=np):
|
||||
return z
|
||||
|
||||
|
||||
@@ -117,8 +116,8 @@ class SympyClamped(sp.Function):
|
||||
return SympyClip(z, -1, 1)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.clip(z, -1, 1)
|
||||
def numerical_eval(z, backend=np):
|
||||
return backend.clip(z, -1, 1)
|
||||
|
||||
|
||||
class SympyInv(sp.Function):
|
||||
@@ -130,8 +129,8 @@ class SympyInv(sp.Function):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.maximum(z, 1e-7)
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.maximum(z, 1e-7)
|
||||
return 1 / z
|
||||
|
||||
def _sympystr(self, printer):
|
||||
@@ -150,9 +149,9 @@ class SympyLog(sp.Function):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.maximum(z, 1e-7)
|
||||
return np.log(z)
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.maximum(z, 1e-7)
|
||||
return backend.log(z)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"log({self.args[0]})"
|
||||
@@ -169,11 +168,6 @@ class SympyExp(sp.Function):
|
||||
return sp.exp(z)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.clip(z, -10, 10)
|
||||
return np.exp(z)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"exp({self.args[0]})"
|
||||
|
||||
@@ -185,7 +179,3 @@ class SympyAbs(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Abs(z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.abs(z)
|
||||
@@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
|
||||
|
||||
@@ -51,15 +52,6 @@ class SympyMedian(sp.Function):
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(args):
|
||||
sorted_args = sorted(args)
|
||||
n = len(sorted_args)
|
||||
if n % 2 == 1:
|
||||
return sorted_args[n // 2]
|
||||
else:
|
||||
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"median({', '.join(map(str, self.args))})"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user