add sympy support; which can transfer your network into sympy expression;

add visualize in genome;
add related tests.
This commit is contained in:
wls2002
2024-06-12 21:36:35 +08:00
parent dfc8f9198e
commit b3e442c688
29 changed files with 6196 additions and 168 deletions

View File

@@ -1,6 +1,51 @@
from .activation import Act, act_func, ACT_ALL
from .aggregation import Agg, agg_func, AGG_ALL
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
from .tools import *
from .graph import *
from .state import State
from .stateful_class import StatefulBaseClass
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
from .activation.act_jnp import Act, ACT_ALL, act_func
from .aggregation.agg_sympy import *
from .activation.act_sympy import *
from typing import Union
name2sympy = {
"sigmoid": SympySigmoid,
"tanh": SympyTanh,
"sin": SympySin,
"relu": SympyRelu,
"lelu": SympyLelu,
"identity": SympyIdentity,
"clamped": SympyClamped,
"inv": SympyInv,
"log": SympyLog,
"exp": SympyExp,
"abs": SympyAbs,
"sum": SympySum,
"product": SympyProduct,
"max": SympyMax,
"min": SympyMin,
"maxabs": SympyMaxabs,
"mean": SympyMean,
}
def convert_to_sympy(func: Union[str, callable]):
if isinstance(func, str):
name = func
else:
name = func.__name__
if name in name2sympy:
return name2sympy[name]
else:
raise ValueError(
f"Can not convert to sympy! Function {name} not found in name2sympy"
)
FUNCS_MODULE = {}
for cls in name2sympy.values():
if hasattr(cls, "numerical_eval"):
FUNCS_MODULE[cls.__name__] = cls.numerical_eval