add sympy support; which can transfer your network into sympy expression;
add visualize in genome; add related tests.
This commit is contained in:
@@ -1,6 +1,51 @@
|
||||
from .activation import Act, act_func, ACT_ALL
|
||||
from .aggregation import Agg, agg_func, AGG_ALL
|
||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
|
||||
from .activation.act_jnp import Act, ACT_ALL, act_func
|
||||
from .aggregation.agg_sympy import *
|
||||
from .activation.act_sympy import *
|
||||
|
||||
from typing import Union
|
||||
|
||||
name2sympy = {
|
||||
"sigmoid": SympySigmoid,
|
||||
"tanh": SympyTanh,
|
||||
"sin": SympySin,
|
||||
"relu": SympyRelu,
|
||||
"lelu": SympyLelu,
|
||||
"identity": SympyIdentity,
|
||||
"clamped": SympyClamped,
|
||||
"inv": SympyInv,
|
||||
"log": SympyLog,
|
||||
"exp": SympyExp,
|
||||
"abs": SympyAbs,
|
||||
"sum": SympySum,
|
||||
"product": SympyProduct,
|
||||
"max": SympyMax,
|
||||
"min": SympyMin,
|
||||
"maxabs": SympyMaxabs,
|
||||
"mean": SympyMean,
|
||||
}
|
||||
|
||||
|
||||
def convert_to_sympy(func: Union[str, callable]):
|
||||
if isinstance(func, str):
|
||||
name = func
|
||||
else:
|
||||
name = func.__name__
|
||||
if name in name2sympy:
|
||||
return name2sympy[name]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Can not convert to sympy! Function {name} not found in name2sympy"
|
||||
)
|
||||
|
||||
|
||||
FUNCS_MODULE = {}
|
||||
for cls in name2sympy.values():
|
||||
if hasattr(cls, "numerical_eval"):
|
||||
FUNCS_MODULE[cls.__name__] = cls.numerical_eval
|
||||
|
||||
0
tensorneat/utils/activation/__init__.py
Normal file
0
tensorneat/utils/activation/__init__.py
Normal file
@@ -3,6 +3,10 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Act:
|
||||
@staticmethod
|
||||
def name2func(name):
|
||||
return getattr(Act, name)
|
||||
|
||||
@staticmethod
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(5 * z, -10, 10)
|
||||
191
tensorneat/utils/activation/act_sympy.py
Normal file
191
tensorneat/utils/activation/act_sympy.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from typing import Union
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SympyClip(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, val, min_val, max_val):
|
||||
if val.is_Number and min_val.is_Number and max_val.is_Number:
|
||||
return sp.Piecewise(
|
||||
(min_val, val < min_val), (max_val, val > max_val), (val, True)
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(val, min_val, max_val):
|
||||
return np.clip(val, min_val, max_val)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(5 * z, -10, 10)
|
||||
return 1 / (1 + sp.exp(-z))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.clip(5 * z, -10, 10)
|
||||
return 1 / (1 + np.exp(-z))
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"sigmoid({self.args[0]})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyTanh(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.tanh(0.6 * z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.tanh(0.6 * z)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.sin(z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.sin(z)
|
||||
|
||||
|
||||
class SympyRelu(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
return sp.Piecewise((z, z > 0), (0, True))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.maximum(z, 0)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"relu({self.args[0]})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyLelu(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
leaky = 0.005
|
||||
return sp.Piecewise((z, z > 0), (leaky * z, True))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
leaky = 0.005
|
||||
return np.maximum(z, leaky * z)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"lelu({self.args[0]})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyIdentity(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return z
|
||||
|
||||
|
||||
class SympyClamped(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return SympyClip(z, -1, 1)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.clip(z, -1, 1)
|
||||
|
||||
|
||||
class SympyInv(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
|
||||
return 1 / z
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.maximum(z, 1e-7)
|
||||
return 1 / z
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"1 / {self.args[0]}"
|
||||
|
||||
def _latex(self, printer):
|
||||
return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
|
||||
|
||||
|
||||
class SympyLog(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = sp.Max(z, 1e-7)
|
||||
return sp.log(z)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.maximum(z, 1e-7)
|
||||
return np.log(z)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"log({self.args[0]})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyExp(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(z, -10, 10)
|
||||
return sp.exp(z)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
z = np.clip(z, -10, 10)
|
||||
return np.exp(z)
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"exp({self.args[0]})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyAbs(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Abs(z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z):
|
||||
return np.abs(z)
|
||||
0
tensorneat/utils/aggregation/__init__.py
Normal file
0
tensorneat/utils/aggregation/__init__.py
Normal file
@@ -3,6 +3,10 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Agg:
|
||||
@staticmethod
|
||||
def name2func(name):
|
||||
return getattr(Agg, name)
|
||||
|
||||
@staticmethod
|
||||
def sum(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
69
tensorneat/utils/aggregation/agg_sympy.py
Normal file
69
tensorneat/utils/aggregation/agg_sympy.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import sympy as sp
|
||||
|
||||
|
||||
class SympySum(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Mul(*z)
|
||||
|
||||
|
||||
class SympyMax(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Max(*z)
|
||||
|
||||
|
||||
class SympyMin(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Min(*z)
|
||||
|
||||
|
||||
class SympyMaxabs(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Max(*z, key=sp.Abs)
|
||||
|
||||
|
||||
class SympyMean(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z) / len(z)
|
||||
|
||||
|
||||
class SympyMedian(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, args):
|
||||
|
||||
if all(arg.is_number for arg in args):
|
||||
sorted_args = sorted(args)
|
||||
n = len(sorted_args)
|
||||
if n % 2 == 1:
|
||||
return sorted_args[n // 2]
|
||||
else:
|
||||
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(args):
|
||||
sorted_args = sorted(args)
|
||||
n = len(sorted_args)
|
||||
if n % 2 == 1:
|
||||
return sorted_args[n // 2]
|
||||
else:
|
||||
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"median({', '.join(map(str, self.args))})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return (
|
||||
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
|
||||
)
|
||||
@@ -5,6 +5,7 @@ Only used in feed-forward networks.
|
||||
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
from typing import Tuple, Set, List, Union
|
||||
|
||||
from .tools import fetch_first, I_INF
|
||||
|
||||
@@ -41,6 +42,60 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
return res
|
||||
|
||||
|
||||
def topological_sort_python(
|
||||
nodes: Union[Set[int], List[int]],
|
||||
conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
|
||||
) -> Tuple[List[int], List[List[int]]]:
|
||||
# a python version of topological_sort, use python set to store nodes and conns
|
||||
# returns the topological order of the nodes and the topological layers
|
||||
# written by gpt4 :)
|
||||
|
||||
# Make a copy of the input nodes and connections
|
||||
nodes = nodes.copy()
|
||||
conns = conns.copy()
|
||||
|
||||
# Initialize the in-degree of each node to 0
|
||||
in_degree = {node: 0 for node in nodes}
|
||||
|
||||
# Compute the in-degree for each node
|
||||
for conn in conns:
|
||||
in_degree[conn[1]] += 1
|
||||
|
||||
topo_order = []
|
||||
topo_layer = []
|
||||
|
||||
# Find all nodes with in-degree 0
|
||||
zero_in_degree_nodes = [node for node in nodes if in_degree[node] == 0]
|
||||
|
||||
while zero_in_degree_nodes:
|
||||
|
||||
for node in zero_in_degree_nodes:
|
||||
nodes.remove(node)
|
||||
|
||||
zero_in_degree_nodes = sorted(
|
||||
zero_in_degree_nodes
|
||||
) # make sure the topo_order is from small to large
|
||||
|
||||
topo_layer.append(zero_in_degree_nodes.copy())
|
||||
|
||||
for node in zero_in_degree_nodes:
|
||||
topo_order.append(node)
|
||||
|
||||
# Iterate over all connections and reduce the in-degree of connected nodes
|
||||
for conn in list(conns):
|
||||
if conn[0] == node:
|
||||
in_degree[conn[1]] -= 1
|
||||
conns.remove(conn)
|
||||
|
||||
zero_in_degree_nodes = [node for node in nodes if in_degree[node] == 0]
|
||||
|
||||
# Check if there are still connections left indicating a cycle
|
||||
if conns or nodes:
|
||||
raise ValueError("Graph has at least one cycle, topological sort not possible")
|
||||
|
||||
return topo_order, topo_layer
|
||||
|
||||
|
||||
@jit
|
||||
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user