add sympy support; which can transfer your network into sympy expression;

add visualize in genome;
add related tests.
This commit is contained in:
wls2002
2024-06-12 21:36:35 +08:00
parent dfc8f9198e
commit b3e442c688
29 changed files with 6196 additions and 168 deletions

View File

View File

@@ -0,0 +1,89 @@
import jax
import jax.numpy as jnp
class Act:
@staticmethod
def name2func(name):
return getattr(Act, name)
@staticmethod
def sigmoid(z):
z = jnp.clip(5 * z, -10, 10)
return 1 / (1 + jnp.exp(-z))
@staticmethod
def tanh(z):
return jnp.tanh(0.6 * z)
@staticmethod
def sin(z):
return jnp.sin(z)
@staticmethod
def relu(z):
return jnp.maximum(z, 0)
@staticmethod
def lelu(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def identity(z):
return z
@staticmethod
def clamped(z):
return jnp.clip(z, -1, 1)
@staticmethod
def inv(z):
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
return 1 / z
@staticmethod
def log(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@staticmethod
def exp(z):
z = jnp.clip(z, -10, 10)
return jnp.exp(z)
@staticmethod
def abs(z):
return jnp.abs(z)
ACT_ALL = (
Act.sigmoid,
Act.tanh,
Act.sin,
Act.relu,
Act.lelu,
Act.identity,
Act.clamped,
Act.inv,
Act.log,
Act.exp,
Act.abs,
)
def act_func(idx, z, act_funcs):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
# -1 means identity activation
res = jax.lax.cond(
idx == -1,
lambda: z,
lambda: jax.lax.switch(idx, act_funcs, z),
)
return res

View File

@@ -0,0 +1,191 @@
from typing import Union
import sympy as sp
import numpy as np
class SympyClip(sp.Function):
@classmethod
def eval(cls, val, min_val, max_val):
if val.is_Number and min_val.is_Number and max_val.is_Number:
return sp.Piecewise(
(min_val, val < min_val), (max_val, val > max_val), (val, True)
)
return None
@staticmethod
def numerical_eval(val, min_val, max_val):
return np.clip(val, min_val, max_val)
def _sympystr(self, printer):
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
def _latex(self, printer):
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
class SympySigmoid(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(5 * z, -10, 10)
return 1 / (1 + sp.exp(-z))
return None
@staticmethod
def numerical_eval(z):
z = np.clip(5 * z, -10, 10)
return 1 / (1 + np.exp(-z))
def _sympystr(self, printer):
return f"sigmoid({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
class SympyTanh(sp.Function):
@classmethod
def eval(cls, z):
return sp.tanh(0.6 * z)
@staticmethod
def numerical_eval(z):
return np.tanh(0.6 * z)
class SympySin(sp.Function):
@classmethod
def eval(cls, z):
return sp.sin(z)
@staticmethod
def numerical_eval(z):
return np.sin(z)
class SympyRelu(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
return sp.Piecewise((z, z > 0), (0, True))
return None
@staticmethod
def numerical_eval(z):
return np.maximum(z, 0)
def _sympystr(self, printer):
return f"relu({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
class SympyLelu(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
leaky = 0.005
return sp.Piecewise((z, z > 0), (leaky * z, True))
return None
@staticmethod
def numerical_eval(z):
leaky = 0.005
return np.maximum(z, leaky * z)
def _sympystr(self, printer):
return f"lelu({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
class SympyIdentity(sp.Function):
@classmethod
def eval(cls, z):
return z
@staticmethod
def numerical_eval(z):
return z
class SympyClamped(sp.Function):
@classmethod
def eval(cls, z):
return SympyClip(z, -1, 1)
@staticmethod
def numerical_eval(z):
return np.clip(z, -1, 1)
class SympyInv(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
return 1 / z
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return 1 / z
def _sympystr(self, printer):
return f"1 / {self.args[0]}"
def _latex(self, printer):
return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
class SympyLog(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = sp.Max(z, 1e-7)
return sp.log(z)
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return np.log(z)
def _sympystr(self, printer):
return f"log({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
class SympyExp(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(z, -10, 10)
return sp.exp(z)
return None
@staticmethod
def numerical_eval(z):
z = np.clip(z, -10, 10)
return np.exp(z)
def _sympystr(self, printer):
return f"exp({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
class SympyAbs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Abs(z)
@staticmethod
def numerical_eval(z):
return np.abs(z)