add backend="jax" to sympy module

This commit is contained in:
wls2002
2024-06-13 05:55:33 +08:00
parent b3e442c688
commit 69d73aab73
12 changed files with 254 additions and 167 deletions

View File

@@ -1,4 +1,3 @@
from typing import Union
import sympy as sp
import numpy as np
@@ -13,8 +12,8 @@ class SympyClip(sp.Function):
return None
@staticmethod
def numerical_eval(val, min_val, max_val):
return np.clip(val, min_val, max_val)
def numerical_eval(val, min_val, max_val, backend=np):
return backend.clip(val, min_val, max_val)
def _sympystr(self, printer):
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
@@ -32,9 +31,9 @@ class SympySigmoid(sp.Function):
return None
@staticmethod
def numerical_eval(z):
z = np.clip(5 * z, -10, 10)
return 1 / (1 + np.exp(-z))
def numerical_eval(z, backend=np):
z = backend.clip(5 * z, -10, 10)
return 1 / (1 + backend.exp(-z))
def _sympystr(self, printer):
return f"sigmoid({self.args[0]})"
@@ -49,8 +48,8 @@ class SympyTanh(sp.Function):
return sp.tanh(0.6 * z)
@staticmethod
def numerical_eval(z):
return np.tanh(0.6 * z)
def numerical_eval(z, backend=np):
return backend.tanh(0.6 * z)
class SympySin(sp.Function):
@@ -59,8 +58,8 @@ class SympySin(sp.Function):
return sp.sin(z)
@staticmethod
def numerical_eval(z):
return np.sin(z)
def numerical_eval(z, backend=np):
return backend.sin(z)
class SympyRelu(sp.Function):
@@ -71,8 +70,8 @@ class SympyRelu(sp.Function):
return None
@staticmethod
def numerical_eval(z):
return np.maximum(z, 0)
def numerical_eval(z, backend=np):
return backend.maximum(z, 0)
def _sympystr(self, printer):
return f"relu({self.args[0]})"
@@ -90,9 +89,9 @@ class SympyLelu(sp.Function):
return None
@staticmethod
def numerical_eval(z):
def numerical_eval(z, backend=np):
leaky = 0.005
return np.maximum(z, leaky * z)
return backend.maximum(z, leaky * z)
def _sympystr(self, printer):
return f"lelu({self.args[0]})"
@@ -107,7 +106,7 @@ class SympyIdentity(sp.Function):
return z
@staticmethod
def numerical_eval(z):
def numerical_eval(z, backend=np):
return z
@@ -117,8 +116,8 @@ class SympyClamped(sp.Function):
return SympyClip(z, -1, 1)
@staticmethod
def numerical_eval(z):
return np.clip(z, -1, 1)
def numerical_eval(z, backend=np):
return backend.clip(z, -1, 1)
class SympyInv(sp.Function):
@@ -130,8 +129,8 @@ class SympyInv(sp.Function):
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
def numerical_eval(z, backend=np):
z = backend.maximum(z, 1e-7)
return 1 / z
def _sympystr(self, printer):
@@ -150,9 +149,9 @@ class SympyLog(sp.Function):
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return np.log(z)
def numerical_eval(z, backend=np):
z = backend.maximum(z, 1e-7)
return backend.log(z)
def _sympystr(self, printer):
return f"log({self.args[0]})"
@@ -169,11 +168,6 @@ class SympyExp(sp.Function):
return sp.exp(z)
return None
@staticmethod
def numerical_eval(z):
z = np.clip(z, -10, 10)
return np.exp(z)
def _sympystr(self, printer):
return f"exp({self.args[0]})"
@@ -185,7 +179,3 @@ class SympyAbs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Abs(z)
@staticmethod
def numerical_eval(z):
return np.abs(z)