update functions
This commit is contained in:
@@ -1,110 +0,0 @@
|
|||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
|
|
||||||
sigma_3 = 2.576
|
|
||||||
|
|
||||||
|
|
||||||
class Act:
|
|
||||||
@staticmethod
|
|
||||||
def name2func(name):
|
|
||||||
return getattr(Act, name)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def sigmoid(z):
|
|
||||||
z = 5 * z / sigma_3
|
|
||||||
z = 1 / (1 + jnp.exp(-z))
|
|
||||||
|
|
||||||
return z * sigma_3 # (0, sigma_3)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def standard_sigmoid(z):
|
|
||||||
z = 5 * z / sigma_3
|
|
||||||
z = 1 / (1 + jnp.exp(-z))
|
|
||||||
|
|
||||||
return z # (0, 1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tanh(z):
|
|
||||||
z = 5 * z / sigma_3
|
|
||||||
return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def standard_tanh(z):
|
|
||||||
z = 5 * z / sigma_3
|
|
||||||
return jnp.tanh(z) # (-1, 1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def sin(z):
|
|
||||||
z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2)
|
|
||||||
return jnp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def relu(z):
|
|
||||||
z = jnp.clip(z, -sigma_3, sigma_3)
|
|
||||||
return jnp.maximum(z, 0) # (0, sigma_3)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def lelu(z):
|
|
||||||
leaky = 0.005
|
|
||||||
z = jnp.clip(z, -sigma_3, sigma_3)
|
|
||||||
return jnp.where(z > 0, z, leaky * z)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def identity(z):
|
|
||||||
return z
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def inv(z):
|
|
||||||
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
|
|
||||||
return 1 / z
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def log(z):
|
|
||||||
z = jnp.maximum(z, 1e-7)
|
|
||||||
return jnp.log(z)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def exp(z):
|
|
||||||
z = jnp.clip(z, -10, 10)
|
|
||||||
return jnp.exp(z)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def square(z):
|
|
||||||
return jnp.pow(z, 2)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def abs(z):
|
|
||||||
z = jnp.clip(z, -1, 1)
|
|
||||||
return jnp.abs(z)
|
|
||||||
|
|
||||||
|
|
||||||
ACT_ALL = (
|
|
||||||
Act.sigmoid,
|
|
||||||
Act.tanh,
|
|
||||||
Act.sin,
|
|
||||||
Act.relu,
|
|
||||||
Act.lelu,
|
|
||||||
Act.identity,
|
|
||||||
Act.inv,
|
|
||||||
Act.log,
|
|
||||||
Act.exp,
|
|
||||||
Act.abs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def act_func(idx, z, act_funcs):
|
|
||||||
"""
|
|
||||||
calculate activation function for each node
|
|
||||||
"""
|
|
||||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
|
||||||
# change idx from float to int
|
|
||||||
|
|
||||||
# -1 means identity activation
|
|
||||||
res = jax.lax.cond(
|
|
||||||
idx == -1,
|
|
||||||
lambda: z,
|
|
||||||
lambda: jax.lax.switch(idx, act_funcs, z),
|
|
||||||
)
|
|
||||||
|
|
||||||
return res
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
import sympy as sp
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
sigma_3 = 2.576
|
|
||||||
|
|
||||||
|
|
||||||
class SympyClip(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, val, min_val, max_val):
|
|
||||||
if val.is_Number and min_val.is_Number and max_val.is_Number:
|
|
||||||
return sp.Piecewise(
|
|
||||||
(min_val, val < min_val), (max_val, val > max_val), (val, True)
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(val, min_val, max_val, backend=np):
|
|
||||||
return backend.clip(val, min_val, max_val)
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
|
||||||
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
|
|
||||||
|
|
||||||
def _latex(self, printer):
|
|
||||||
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
|
|
||||||
|
|
||||||
|
|
||||||
class SympySigmoid_(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
z = 1 / (1 + sp.exp(-z))
|
|
||||||
return z
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z, backend=np):
|
|
||||||
z = 1 / (1 + backend.exp(-z))
|
|
||||||
return z
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
|
||||||
return f"sigmoid({self.args[0]})"
|
|
||||||
|
|
||||||
def _latex(self, printer):
|
|
||||||
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
|
|
||||||
|
|
||||||
|
|
||||||
class SympySigmoid(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
return SympySigmoid_(5 * z / sigma_3) * sigma_3
|
|
||||||
|
|
||||||
|
|
||||||
class SympyStandardSigmoid(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
return SympySigmoid_(5 * z / sigma_3)
|
|
||||||
|
|
||||||
|
|
||||||
class SympyTanh(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
z = 5 * z / sigma_3
|
|
||||||
return sp.tanh(z) * sigma_3
|
|
||||||
|
|
||||||
|
|
||||||
class SympyStandardTanh(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
z = 5 * z / sigma_3
|
|
||||||
return sp.tanh(z)
|
|
||||||
|
|
||||||
|
|
||||||
class SympySin(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
if z.is_Number:
|
|
||||||
z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2)
|
|
||||||
return sp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z, backend=np):
|
|
||||||
z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2)
|
|
||||||
return backend.sin(z) * sigma_3 # (-sigma_3, sigma_3)
|
|
||||||
|
|
||||||
|
|
||||||
class SympyRelu(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
if z.is_Number:
|
|
||||||
z = SympyClip(z, -sigma_3, sigma_3)
|
|
||||||
return sp.Max(z, 0) # (0, sigma_3)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z, backend=np):
|
|
||||||
z = backend.clip(z, -sigma_3, sigma_3)
|
|
||||||
return backend.maximum(z, 0) # (0, sigma_3)
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
|
||||||
return f"relu({self.args[0]})"
|
|
||||||
|
|
||||||
def _latex(self, printer):
|
|
||||||
return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
|
|
||||||
|
|
||||||
|
|
||||||
class SympyLelu(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
if z.is_Number:
|
|
||||||
leaky = 0.005
|
|
||||||
return sp.Piecewise((z, z > 0), (leaky * z, True))
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z, backend=np):
|
|
||||||
leaky = 0.005
|
|
||||||
return backend.maximum(z, leaky * z)
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
|
||||||
return f"lelu({self.args[0]})"
|
|
||||||
|
|
||||||
def _latex(self, printer):
|
|
||||||
return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
|
|
||||||
|
|
||||||
|
|
||||||
class SympyIdentity(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
class SympyInv(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
if z.is_Number:
|
|
||||||
z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
|
|
||||||
return 1 / z
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z, backend=np):
|
|
||||||
z = backend.maximum(z, 1e-7)
|
|
||||||
return 1 / z
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
|
||||||
return f"1 / {self.args[0]}"
|
|
||||||
|
|
||||||
def _latex(self, printer):
|
|
||||||
return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
|
|
||||||
|
|
||||||
|
|
||||||
class SympyLog(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
if z.is_Number:
|
|
||||||
z = sp.Max(z, 1e-7)
|
|
||||||
return sp.log(z)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z, backend=np):
|
|
||||||
z = backend.maximum(z, 1e-7)
|
|
||||||
return backend.log(z)
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
|
||||||
return f"log({self.args[0]})"
|
|
||||||
|
|
||||||
def _latex(self, printer):
|
|
||||||
return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
|
|
||||||
|
|
||||||
|
|
||||||
class SympyExp(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
if z.is_Number:
|
|
||||||
z = SympyClip(z, -10, 10)
|
|
||||||
return sp.exp(z)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
|
||||||
return f"exp({self.args[0]})"
|
|
||||||
|
|
||||||
def _latex(self, printer):
|
|
||||||
return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
|
|
||||||
|
|
||||||
|
|
||||||
class SympySquare(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
return sp.Pow(z, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class SympyAbs(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
return sp.Abs(z)
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
|
|
||||||
class Agg:
|
|
||||||
@staticmethod
|
|
||||||
def name2func(name):
|
|
||||||
return getattr(Agg, name)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def sum(z):
|
|
||||||
return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def product(z):
|
|
||||||
return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def max(z):
|
|
||||||
return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def min(z):
|
|
||||||
return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def maxabs(z):
|
|
||||||
z = jnp.where(jnp.isnan(z), 0, z)
|
|
||||||
abs_z = jnp.abs(z)
|
|
||||||
max_abs_index = jnp.argmax(abs_z)
|
|
||||||
return z[max_abs_index]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def median(z):
|
|
||||||
n = jnp.sum(~jnp.isnan(z), axis=0)
|
|
||||||
|
|
||||||
z = jnp.sort(z) # sort
|
|
||||||
|
|
||||||
idx1, idx2 = (n - 1) // 2, n // 2
|
|
||||||
median = (z[idx1] + z[idx2]) / 2
|
|
||||||
|
|
||||||
return median
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mean(z):
|
|
||||||
aux = jnp.where(jnp.isnan(z), 0, z)
|
|
||||||
valid_values_sum = jnp.sum(aux, axis=0)
|
|
||||||
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
|
|
||||||
mean_without_zeros = valid_values_sum / valid_values_count
|
|
||||||
return mean_without_zeros
|
|
||||||
|
|
||||||
|
|
||||||
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
|
|
||||||
|
|
||||||
|
|
||||||
def agg_func(idx, z, agg_funcs):
|
|
||||||
"""
|
|
||||||
calculate activation function for inputs of node
|
|
||||||
"""
|
|
||||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
|
||||||
|
|
||||||
return jax.lax.cond(
|
|
||||||
jnp.all(jnp.isnan(z)),
|
|
||||||
lambda: jnp.nan, # all inputs are nan
|
|
||||||
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
|
|
||||||
)
|
|
||||||
58
src/tensorneat/common/functions/__init__.py
Normal file
58
src/tensorneat/common/functions/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from .act_jnp import *
|
||||||
|
from .act_sympy import *
|
||||||
|
from .agg_jnp import *
|
||||||
|
from .agg_sympy import *
|
||||||
|
from .manager import FunctionManager
|
||||||
|
|
||||||
|
act_name2jnp = {
|
||||||
|
"scaled_sigmoid": scaled_sigmoid_,
|
||||||
|
"sigmoid": sigmoid_,
|
||||||
|
"scaled_tanh": scaled_tanh_,
|
||||||
|
"tanh": tanh_,
|
||||||
|
"sin": sin_,
|
||||||
|
"relu": relu_,
|
||||||
|
"lelu": lelu_,
|
||||||
|
"identity": identity_,
|
||||||
|
"inv": inv_,
|
||||||
|
"log": log_,
|
||||||
|
"exp": exp_,
|
||||||
|
"abs": abs_,
|
||||||
|
}
|
||||||
|
|
||||||
|
act_name2sympy = {
|
||||||
|
"scaled_sigmoid": SympyScaledSigmoid,
|
||||||
|
"sigmoid": SympySigmoid,
|
||||||
|
"scaled_tanh": SympyScaledTanh,
|
||||||
|
"tanh": SympyTanh,
|
||||||
|
"sin": SympySin,
|
||||||
|
"relu": SympyRelu,
|
||||||
|
"lelu": SympyLelu,
|
||||||
|
"identity": SympyIdentity,
|
||||||
|
"inv": SympyIdentity,
|
||||||
|
"log": SympyLog,
|
||||||
|
"exp": SympyExp,
|
||||||
|
"abs": SympyAbs,
|
||||||
|
}
|
||||||
|
|
||||||
|
agg_name2jnp = {
|
||||||
|
"sum": sum_,
|
||||||
|
"product": product_,
|
||||||
|
"max": max_,
|
||||||
|
"min": min_,
|
||||||
|
"maxabs": maxabs_,
|
||||||
|
"median": median_,
|
||||||
|
"mean": mean_,
|
||||||
|
}
|
||||||
|
|
||||||
|
agg_name2sympy = {
|
||||||
|
"sum": SympySum,
|
||||||
|
"product": SympyProduct,
|
||||||
|
"max": SympyMax,
|
||||||
|
"min": SympyMin,
|
||||||
|
"maxabs": SympyMaxabs,
|
||||||
|
"median": SympyMedian,
|
||||||
|
"mean": SympyMean,
|
||||||
|
}
|
||||||
|
|
||||||
|
ACT = FunctionManager(act_name2jnp, act_name2sympy)
|
||||||
|
AGG = FunctionManager(agg_name2jnp, agg_name2sympy)
|
||||||
57
src/tensorneat/common/functions/act_jnp.py
Normal file
57
src/tensorneat/common/functions/act_jnp.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
SCALE = 5
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_sigmoid_(z):
|
||||||
|
z = 1 / (1 + jnp.exp(-z))
|
||||||
|
return z * SCALE
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid_(z):
|
||||||
|
z = 1 / (1 + jnp.exp(-z))
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_tanh_(z):
|
||||||
|
return jnp.tanh(z) * SCALE
|
||||||
|
|
||||||
|
|
||||||
|
def tanh_(z):
|
||||||
|
return jnp.tanh(z)
|
||||||
|
|
||||||
|
|
||||||
|
def sin_(z):
|
||||||
|
return jnp.sin(z)
|
||||||
|
|
||||||
|
|
||||||
|
def relu_(z):
|
||||||
|
return jnp.maximum(z, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def lelu_(z):
|
||||||
|
leaky = 0.005
|
||||||
|
return jnp.where(z > 0, z, leaky * z)
|
||||||
|
|
||||||
|
|
||||||
|
def identity_(z):
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
def inv_(z):
|
||||||
|
# avoid division by zero
|
||||||
|
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
|
||||||
|
return 1 / z
|
||||||
|
|
||||||
|
|
||||||
|
def log_(z):
|
||||||
|
z = jnp.maximum(z, 1e-7)
|
||||||
|
return jnp.log(z)
|
||||||
|
|
||||||
|
|
||||||
|
def exp_(z):
|
||||||
|
return jnp.exp(z)
|
||||||
|
|
||||||
|
|
||||||
|
def abs_(z):
|
||||||
|
return jnp.abs(z)
|
||||||
100
src/tensorneat/common/functions/act_sympy.py
Normal file
100
src/tensorneat/common/functions/act_sympy.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import sympy as sp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
SCALE = 5
|
||||||
|
|
||||||
|
class SympySigmoid(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
z = 1 / (1 + sp.exp(-z))
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class SympyScaledSigmoid(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
return SympySigmoid(z) * SCALE
|
||||||
|
|
||||||
|
|
||||||
|
class SympyTanh(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
return sp.tanh(z)
|
||||||
|
|
||||||
|
|
||||||
|
class SympyScaledTanh(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
return SympyTanh(z) * SCALE
|
||||||
|
|
||||||
|
|
||||||
|
class SympySin(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
return sp.sin(z)
|
||||||
|
|
||||||
|
|
||||||
|
class SympyRelu(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
return sp.Max(z, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class SympyLelu(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
leaky = 0.005
|
||||||
|
return sp.Piecewise((z, z > 0), (leaky * z, True))
|
||||||
|
|
||||||
|
|
||||||
|
class SympyIdentity(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class SympyInv(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
|
||||||
|
return 1 / z
|
||||||
|
|
||||||
|
|
||||||
|
class SympyLog(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
z = sp.Max(z, 1e-7)
|
||||||
|
return sp.log(z)
|
||||||
|
|
||||||
|
|
||||||
|
class SympyExp(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
z = SympyClip(z, -10, 10)
|
||||||
|
return sp.exp(z)
|
||||||
|
|
||||||
|
|
||||||
|
class SympyAbs(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
return sp.Abs(z)
|
||||||
|
|
||||||
|
|
||||||
|
class SympyClip(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, val, min_val, max_val):
|
||||||
|
if val.is_Number and min_val.is_Number and max_val.is_Number:
|
||||||
|
return sp.Piecewise(
|
||||||
|
(min_val, val < min_val), (max_val, val > max_val), (val, True)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def numerical_eval(val, min_val, max_val, backend=np):
|
||||||
|
return backend.clip(val, min_val, max_val)
|
||||||
|
|
||||||
|
def _sympystr(self, printer):
|
||||||
|
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
|
||||||
|
|
||||||
|
def _latex(self, printer):
|
||||||
|
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
|
||||||
41
src/tensorneat/common/functions/agg_jnp.py
Normal file
41
src/tensorneat/common/functions/agg_jnp.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
|
def sum_(z):
|
||||||
|
return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0)
|
||||||
|
|
||||||
|
|
||||||
|
def product_(z):
|
||||||
|
return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1)
|
||||||
|
|
||||||
|
|
||||||
|
def max_(z):
|
||||||
|
return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf)
|
||||||
|
|
||||||
|
|
||||||
|
def min_(z):
|
||||||
|
return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf)
|
||||||
|
|
||||||
|
|
||||||
|
def maxabs_(z):
|
||||||
|
z = jnp.where(jnp.isnan(z), 0, z)
|
||||||
|
abs_z = jnp.abs(z)
|
||||||
|
max_abs_index = jnp.argmax(abs_z)
|
||||||
|
return z[max_abs_index]
|
||||||
|
|
||||||
|
|
||||||
|
def median_(z):
|
||||||
|
n = jnp.sum(~jnp.isnan(z), axis=0)
|
||||||
|
|
||||||
|
z = jnp.sort(z) # sort
|
||||||
|
|
||||||
|
idx1, idx2 = (n - 1) // 2, n // 2
|
||||||
|
median = (z[idx1] + z[idx2]) / 2
|
||||||
|
|
||||||
|
return median
|
||||||
|
|
||||||
|
|
||||||
|
def mean_(z):
|
||||||
|
sumation = sum_(z)
|
||||||
|
valid_count = jnp.sum(~jnp.isnan(z), axis=0)
|
||||||
|
return sumation / valid_count
|
||||||
@@ -17,12 +17,20 @@ class SympyProduct(sp.Function):
|
|||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
return sp.Mul(*z)
|
return sp.Mul(*z)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def numerical_eval(cls, z, backend=np):
|
||||||
|
return backend.product(z)
|
||||||
|
|
||||||
|
|
||||||
class SympyMax(sp.Function):
|
class SympyMax(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
return sp.Max(*z)
|
return sp.Max(*z)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def numerical_eval(cls, z, backend=np):
|
||||||
|
return backend.max(z)
|
||||||
|
|
||||||
|
|
||||||
class SympyMin(sp.Function):
|
class SympyMin(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
49
src/tensorneat/common/functions/manager.py
Normal file
49
src/tensorneat/common/functions/manager.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from typing import Union, Callable
|
||||||
|
import sympy as sp
|
||||||
|
|
||||||
|
class FunctionManager:
|
||||||
|
|
||||||
|
def __init__(self, name2jnp, name2sympy):
|
||||||
|
self.name2jnp = name2jnp
|
||||||
|
self.name2sympy = name2sympy
|
||||||
|
|
||||||
|
def get_all_funcs(self):
|
||||||
|
all_funcs = []
|
||||||
|
for name in self.names:
|
||||||
|
all_funcs.append(getattr(self, name))
|
||||||
|
return all_funcs
|
||||||
|
|
||||||
|
def __getattribute__(self, name: str):
|
||||||
|
return self.name2jnp[name]
|
||||||
|
|
||||||
|
def add_func(self, name, func):
|
||||||
|
if not callable(func):
|
||||||
|
raise ValueError("The provided function is not callable")
|
||||||
|
if name in self.names:
|
||||||
|
raise ValueError(f"The provided name={name} is already in use")
|
||||||
|
|
||||||
|
self.name2jnp[name] = func
|
||||||
|
|
||||||
|
def update_sympy(self, name, sympy_cls: sp.Function):
|
||||||
|
self.name2sympy[name] = sympy_cls
|
||||||
|
|
||||||
|
def obtain_sympy(self, func: Union[str, Callable]):
|
||||||
|
if isinstance(func, str):
|
||||||
|
if func not in self.name2sympy:
|
||||||
|
raise ValueError(f"Func {func} doesn't have a sympy representation.")
|
||||||
|
return self.name2sympy[func]
|
||||||
|
|
||||||
|
elif isinstance(func, Callable):
|
||||||
|
# try to find name
|
||||||
|
for name, f in self.name2jnp.items():
|
||||||
|
if f == func:
|
||||||
|
return self._obtain_sympy_by_name(name)
|
||||||
|
raise ValueError(f"Func {func} doesn't not registered.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Func {func} need be a string or callable.")
|
||||||
|
|
||||||
|
def _obtain_sympy_by_name(self, name: str):
|
||||||
|
if name not in self.name2sympy:
|
||||||
|
raise ValueError(f"Func {name} doesn't have a sympy representation.")
|
||||||
|
return self.name2sympy[name]
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Some graph algorithm implemented in jax.
|
Some graph algorithm implemented in jax and python.
|
||||||
Only used in feed-forward networks.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from . import State
|
from . import State
|
||||||
import pickle
|
import pickle
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import numpy as np
|
|||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp, Array, jit, vmap
|
from jax import numpy as jnp, Array, jit, vmap
|
||||||
|
|
||||||
I_INF = np.iinfo(jnp.int32).max # infinite int
|
# infinite int, use to represent the unavialable index in int32 array.
|
||||||
|
# as we can not use nan in int32 array
|
||||||
|
I_INF = np.iinfo(jnp.int32).max
|
||||||
|
|
||||||
|
|
||||||
def attach_with_inf(arr, idx):
|
def attach_with_inf(arr, idx):
|
||||||
@@ -100,6 +102,9 @@ def argmin_with_mask(arr, mask):
|
|||||||
|
|
||||||
|
|
||||||
def hash_array(arr: Array):
|
def hash_array(arr: Array):
|
||||||
|
"""
|
||||||
|
Hash an array of uint32 to a single uint
|
||||||
|
"""
|
||||||
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
|
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
|
||||||
|
|
||||||
def update(i, hash_val):
|
def update(i, hash_val):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import jax
|
import jax
|
||||||
|
from jax import vmap, numpy as jnp
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from ..base import BaseProblem
|
from ..base import BaseProblem
|
||||||
@@ -19,7 +20,7 @@ class FuncFit(BaseProblem):
|
|||||||
|
|
||||||
def evaluate(self, state, randkey, act_func, params):
|
def evaluate(self, state, randkey, act_func, params):
|
||||||
|
|
||||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
predict = vmap(act_func, in_axes=(None, None, 0))(
|
||||||
state, params, self.inputs
|
state, params, self.inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,7 +42,7 @@ class FuncFit(BaseProblem):
|
|||||||
return -loss
|
return -loss
|
||||||
|
|
||||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
predict = vmap(act_func, in_axes=(None, None, 0))(
|
||||||
state, params, self.inputs
|
state, params, self.inputs
|
||||||
)
|
)
|
||||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||||
|
|||||||
Reference in New Issue
Block a user