diff --git a/src/tensorneat/common/activation/__init__.py b/src/tensorneat/common/activation/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/tensorneat/common/activation/act_jnp.py b/src/tensorneat/common/activation/act_jnp.py deleted file mode 100644 index 790152e..0000000 --- a/src/tensorneat/common/activation/act_jnp.py +++ /dev/null @@ -1,110 +0,0 @@ -import jax -import jax.numpy as jnp - - -sigma_3 = 2.576 - - -class Act: - @staticmethod - def name2func(name): - return getattr(Act, name) - - @staticmethod - def sigmoid(z): - z = 5 * z / sigma_3 - z = 1 / (1 + jnp.exp(-z)) - - return z * sigma_3 # (0, sigma_3) - - @staticmethod - def standard_sigmoid(z): - z = 5 * z / sigma_3 - z = 1 / (1 + jnp.exp(-z)) - - return z # (0, 1) - - @staticmethod - def tanh(z): - z = 5 * z / sigma_3 - return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3) - - @staticmethod - def standard_tanh(z): - z = 5 * z / sigma_3 - return jnp.tanh(z) # (-1, 1) - - @staticmethod - def sin(z): - z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2) - return jnp.sin(z) * sigma_3 # (-sigma_3, sigma_3) - - @staticmethod - def relu(z): - z = jnp.clip(z, -sigma_3, sigma_3) - return jnp.maximum(z, 0) # (0, sigma_3) - - @staticmethod - def lelu(z): - leaky = 0.005 - z = jnp.clip(z, -sigma_3, sigma_3) - return jnp.where(z > 0, z, leaky * z) - - @staticmethod - def identity(z): - return z - - @staticmethod - def inv(z): - z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7)) - return 1 / z - - @staticmethod - def log(z): - z = jnp.maximum(z, 1e-7) - return jnp.log(z) - - @staticmethod - def exp(z): - z = jnp.clip(z, -10, 10) - return jnp.exp(z) - - @staticmethod - def square(z): - return jnp.pow(z, 2) - - @staticmethod - def abs(z): - z = jnp.clip(z, -1, 1) - return jnp.abs(z) - - -ACT_ALL = ( - Act.sigmoid, - Act.tanh, - Act.sin, - Act.relu, - Act.lelu, - Act.identity, - Act.inv, - Act.log, - Act.exp, - Act.abs, -) - - -def act_func(idx, z, act_funcs): - """ - calculate activation function for each node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - # change idx from float to int - - # -1 means identity activation - res = jax.lax.cond( - idx == -1, - lambda: z, - lambda: jax.lax.switch(idx, act_funcs, z), - ) - - return res diff --git a/src/tensorneat/common/activation/act_sympy.py b/src/tensorneat/common/activation/act_sympy.py deleted file mode 100644 index c7055a7..0000000 --- a/src/tensorneat/common/activation/act_sympy.py +++ /dev/null @@ -1,196 +0,0 @@ -import sympy as sp -import numpy as np - - -sigma_3 = 2.576 - - -class SympyClip(sp.Function): - @classmethod - def eval(cls, val, min_val, max_val): - if val.is_Number and min_val.is_Number and max_val.is_Number: - return sp.Piecewise( - (min_val, val < min_val), (max_val, val > max_val), (val, True) - ) - return None - - @staticmethod - def numerical_eval(val, min_val, max_val, backend=np): - return backend.clip(val, min_val, max_val) - - def _sympystr(self, printer): - return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})" - - def _latex(self, printer): - return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)" - - -class SympySigmoid_(sp.Function): - @classmethod - def eval(cls, z): - z = 1 / (1 + sp.exp(-z)) - return z - - @staticmethod - def numerical_eval(z, backend=np): - z = 1 / (1 + backend.exp(-z)) - return z - - def _sympystr(self, printer): - return f"sigmoid({self.args[0]})" - - def _latex(self, printer): - return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)" - - -class SympySigmoid(sp.Function): - @classmethod - def eval(cls, z): - return SympySigmoid_(5 * z / sigma_3) * sigma_3 - - -class SympyStandardSigmoid(sp.Function): - @classmethod - def eval(cls, z): - return SympySigmoid_(5 * z / sigma_3) - - -class SympyTanh(sp.Function): - @classmethod - def eval(cls, z): - z = 5 * z / sigma_3 - return sp.tanh(z) * sigma_3 - - -class SympyStandardTanh(sp.Function): - @classmethod - def eval(cls, z): - z = 5 * z / sigma_3 - return sp.tanh(z) - - -class SympySin(sp.Function): - @classmethod - def eval(cls, z): - if z.is_Number: - z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2) - return sp.sin(z) * sigma_3 # (-sigma_3, sigma_3) - return None - - @staticmethod - def numerical_eval(z, backend=np): - z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2) - return backend.sin(z) * sigma_3 # (-sigma_3, sigma_3) - - -class SympyRelu(sp.Function): - @classmethod - def eval(cls, z): - if z.is_Number: - z = SympyClip(z, -sigma_3, sigma_3) - return sp.Max(z, 0) # (0, sigma_3) - return None - - @staticmethod - def numerical_eval(z, backend=np): - z = backend.clip(z, -sigma_3, sigma_3) - return backend.maximum(z, 0) # (0, sigma_3) - - def _sympystr(self, printer): - return f"relu({self.args[0]})" - - def _latex(self, printer): - return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)" - - -class SympyLelu(sp.Function): - @classmethod - def eval(cls, z): - if z.is_Number: - leaky = 0.005 - return sp.Piecewise((z, z > 0), (leaky * z, True)) - return None - - @staticmethod - def numerical_eval(z, backend=np): - leaky = 0.005 - return backend.maximum(z, leaky * z) - - def _sympystr(self, printer): - return f"lelu({self.args[0]})" - - def _latex(self, printer): - return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)" - - -class SympyIdentity(sp.Function): - @classmethod - def eval(cls, z): - return z - - -class SympyInv(sp.Function): - @classmethod - def eval(cls, z): - if z.is_Number: - z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True)) - return 1 / z - return None - - @staticmethod - def numerical_eval(z, backend=np): - z = backend.maximum(z, 1e-7) - return 1 / z - - def _sympystr(self, printer): - return f"1 / {self.args[0]}" - - def _latex(self, printer): - return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}" - - -class SympyLog(sp.Function): - @classmethod - def eval(cls, z): - if z.is_Number: - z = sp.Max(z, 1e-7) - return sp.log(z) - return None - - @staticmethod - def numerical_eval(z, backend=np): - z = backend.maximum(z, 1e-7) - return backend.log(z) - - def _sympystr(self, printer): - return f"log({self.args[0]})" - - def _latex(self, printer): - return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)" - - -class SympyExp(sp.Function): - @classmethod - def eval(cls, z): - if z.is_Number: - z = SympyClip(z, -10, 10) - return sp.exp(z) - return None - - def _sympystr(self, printer): - return f"exp({self.args[0]})" - - def _latex(self, printer): - return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)" - - -class SympySquare(sp.Function): - @classmethod - def eval(cls, z): - return sp.Pow(z, 2) - - -class SympyAbs(sp.Function): - @classmethod - def eval(cls, z): - return sp.Abs(z) diff --git a/src/tensorneat/common/aggregation/__init__.py b/src/tensorneat/common/aggregation/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/tensorneat/common/aggregation/agg_jnp.py b/src/tensorneat/common/aggregation/agg_jnp.py deleted file mode 100644 index 1be1ef0..0000000 --- a/src/tensorneat/common/aggregation/agg_jnp.py +++ /dev/null @@ -1,66 +0,0 @@ -import jax -import jax.numpy as jnp - - -class Agg: - @staticmethod - def name2func(name): - return getattr(Agg, name) - - @staticmethod - def sum(z): - return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0) - - @staticmethod - def product(z): - return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1) - - @staticmethod - def max(z): - return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf) - - @staticmethod - def min(z): - return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf) - - @staticmethod - def maxabs(z): - z = jnp.where(jnp.isnan(z), 0, z) - abs_z = jnp.abs(z) - max_abs_index = jnp.argmax(abs_z) - return z[max_abs_index] - - @staticmethod - def median(z): - n = jnp.sum(~jnp.isnan(z), axis=0) - - z = jnp.sort(z) # sort - - idx1, idx2 = (n - 1) // 2, n // 2 - median = (z[idx1] + z[idx2]) / 2 - - return median - - @staticmethod - def mean(z): - aux = jnp.where(jnp.isnan(z), 0, z) - valid_values_sum = jnp.sum(aux, axis=0) - valid_values_count = jnp.sum(~jnp.isnan(z), axis=0) - mean_without_zeros = valid_values_sum / valid_values_count - return mean_without_zeros - - -AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean) - - -def agg_func(idx, z, agg_funcs): - """ - calculate activation function for inputs of node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - - return jax.lax.cond( - jnp.all(jnp.isnan(z)), - lambda: jnp.nan, # all inputs are nan - lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise - ) diff --git a/src/tensorneat/common/functions/__init__.py b/src/tensorneat/common/functions/__init__.py new file mode 100644 index 0000000..6e23462 --- /dev/null +++ b/src/tensorneat/common/functions/__init__.py @@ -0,0 +1,58 @@ +from .act_jnp import * +from .act_sympy import * +from .agg_jnp import * +from .agg_sympy import * +from .manager import FunctionManager + +act_name2jnp = { + "scaled_sigmoid": scaled_sigmoid_, + "sigmoid": sigmoid_, + "scaled_tanh": scaled_tanh_, + "tanh": tanh_, + "sin": sin_, + "relu": relu_, + "lelu": lelu_, + "identity": identity_, + "inv": inv_, + "log": log_, + "exp": exp_, + "abs": abs_, +} + +act_name2sympy = { + "scaled_sigmoid": SympyScaledSigmoid, + "sigmoid": SympySigmoid, + "scaled_tanh": SympyScaledTanh, + "tanh": SympyTanh, + "sin": SympySin, + "relu": SympyRelu, + "lelu": SympyLelu, + "identity": SympyIdentity, + "inv": SympyIdentity, + "log": SympyLog, + "exp": SympyExp, + "abs": SympyAbs, +} + +agg_name2jnp = { + "sum": sum_, + "product": product_, + "max": max_, + "min": min_, + "maxabs": maxabs_, + "median": median_, + "mean": mean_, +} + +agg_name2sympy = { + "sum": SympySum, + "product": SympyProduct, + "max": SympyMax, + "min": SympyMin, + "maxabs": SympyMaxabs, + "median": SympyMedian, + "mean": SympyMean, +} + +ACT = FunctionManager(act_name2jnp, act_name2sympy) +AGG = FunctionManager(agg_name2jnp, agg_name2sympy) diff --git a/src/tensorneat/common/functions/act_jnp.py b/src/tensorneat/common/functions/act_jnp.py new file mode 100644 index 0000000..75b0e86 --- /dev/null +++ b/src/tensorneat/common/functions/act_jnp.py @@ -0,0 +1,57 @@ +import jax.numpy as jnp + +SCALE = 5 + + +def scaled_sigmoid_(z): + z = 1 / (1 + jnp.exp(-z)) + return z * SCALE + + +def sigmoid_(z): + z = 1 / (1 + jnp.exp(-z)) + return z + + +def scaled_tanh_(z): + return jnp.tanh(z) * SCALE + + +def tanh_(z): + return jnp.tanh(z) + + +def sin_(z): + return jnp.sin(z) + + +def relu_(z): + return jnp.maximum(z, 0) + + +def lelu_(z): + leaky = 0.005 + return jnp.where(z > 0, z, leaky * z) + + +def identity_(z): + return z + + +def inv_(z): + # avoid division by zero + z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7)) + return 1 / z + + +def log_(z): + z = jnp.maximum(z, 1e-7) + return jnp.log(z) + + +def exp_(z): + return jnp.exp(z) + + +def abs_(z): + return jnp.abs(z) diff --git a/src/tensorneat/common/functions/act_sympy.py b/src/tensorneat/common/functions/act_sympy.py new file mode 100644 index 0000000..1c80594 --- /dev/null +++ b/src/tensorneat/common/functions/act_sympy.py @@ -0,0 +1,100 @@ +import sympy as sp +import numpy as np + +SCALE = 5 + +class SympySigmoid(sp.Function): + @classmethod + def eval(cls, z): + z = 1 / (1 + sp.exp(-z)) + return z + + +class SympyScaledSigmoid(sp.Function): + @classmethod + def eval(cls, z): + return SympySigmoid(z) * SCALE + + +class SympyTanh(sp.Function): + @classmethod + def eval(cls, z): + return sp.tanh(z) + + +class SympyScaledTanh(sp.Function): + @classmethod + def eval(cls, z): + return SympyTanh(z) * SCALE + + +class SympySin(sp.Function): + @classmethod + def eval(cls, z): + return sp.sin(z) + + +class SympyRelu(sp.Function): + @classmethod + def eval(cls, z): + return sp.Max(z, 0) + + +class SympyLelu(sp.Function): + @classmethod + def eval(cls, z): + leaky = 0.005 + return sp.Piecewise((z, z > 0), (leaky * z, True)) + + +class SympyIdentity(sp.Function): + @classmethod + def eval(cls, z): + return z + + +class SympyInv(sp.Function): + @classmethod + def eval(cls, z): + z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True)) + return 1 / z + + +class SympyLog(sp.Function): + @classmethod + def eval(cls, z): + z = sp.Max(z, 1e-7) + return sp.log(z) + + +class SympyExp(sp.Function): + @classmethod + def eval(cls, z): + z = SympyClip(z, -10, 10) + return sp.exp(z) + + +class SympyAbs(sp.Function): + @classmethod + def eval(cls, z): + return sp.Abs(z) + + +class SympyClip(sp.Function): + @classmethod + def eval(cls, val, min_val, max_val): + if val.is_Number and min_val.is_Number and max_val.is_Number: + return sp.Piecewise( + (min_val, val < min_val), (max_val, val > max_val), (val, True) + ) + return None + + @staticmethod + def numerical_eval(val, min_val, max_val, backend=np): + return backend.clip(val, min_val, max_val) + + def _sympystr(self, printer): + return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})" + + def _latex(self, printer): + return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)" diff --git a/src/tensorneat/common/functions/agg_jnp.py b/src/tensorneat/common/functions/agg_jnp.py new file mode 100644 index 0000000..53ca931 --- /dev/null +++ b/src/tensorneat/common/functions/agg_jnp.py @@ -0,0 +1,41 @@ +import jax.numpy as jnp + + +def sum_(z): + return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0) + + +def product_(z): + return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1) + + +def max_(z): + return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf) + + +def min_(z): + return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf) + + +def maxabs_(z): + z = jnp.where(jnp.isnan(z), 0, z) + abs_z = jnp.abs(z) + max_abs_index = jnp.argmax(abs_z) + return z[max_abs_index] + + +def median_(z): + n = jnp.sum(~jnp.isnan(z), axis=0) + + z = jnp.sort(z) # sort + + idx1, idx2 = (n - 1) // 2, n // 2 + median = (z[idx1] + z[idx2]) / 2 + + return median + + +def mean_(z): + sumation = sum_(z) + valid_count = jnp.sum(~jnp.isnan(z), axis=0) + return sumation / valid_count diff --git a/src/tensorneat/common/aggregation/agg_sympy.py b/src/tensorneat/common/functions/agg_sympy.py similarity index 87% rename from src/tensorneat/common/aggregation/agg_sympy.py rename to src/tensorneat/common/functions/agg_sympy.py index 0890a49..85df179 100644 --- a/src/tensorneat/common/aggregation/agg_sympy.py +++ b/src/tensorneat/common/functions/agg_sympy.py @@ -17,12 +17,20 @@ class SympyProduct(sp.Function): def eval(cls, z): return sp.Mul(*z) + @classmethod + def numerical_eval(cls, z, backend=np): + return backend.product(z) + class SympyMax(sp.Function): @classmethod def eval(cls, z): return sp.Max(*z) + @classmethod + def numerical_eval(cls, z, backend=np): + return backend.max(z) + class SympyMin(sp.Function): @classmethod diff --git a/src/tensorneat/common/functions/manager.py b/src/tensorneat/common/functions/manager.py new file mode 100644 index 0000000..cc8b2cc --- /dev/null +++ b/src/tensorneat/common/functions/manager.py @@ -0,0 +1,49 @@ +from typing import Union, Callable +import sympy as sp + +class FunctionManager: + + def __init__(self, name2jnp, name2sympy): + self.name2jnp = name2jnp + self.name2sympy = name2sympy + + def get_all_funcs(self): + all_funcs = [] + for name in self.names: + all_funcs.append(getattr(self, name)) + return all_funcs + + def __getattribute__(self, name: str): + return self.name2jnp[name] + + def add_func(self, name, func): + if not callable(func): + raise ValueError("The provided function is not callable") + if name in self.names: + raise ValueError(f"The provided name={name} is already in use") + + self.name2jnp[name] = func + + def update_sympy(self, name, sympy_cls: sp.Function): + self.name2sympy[name] = sympy_cls + + def obtain_sympy(self, func: Union[str, Callable]): + if isinstance(func, str): + if func not in self.name2sympy: + raise ValueError(f"Func {func} doesn't have a sympy representation.") + return self.name2sympy[func] + + elif isinstance(func, Callable): + # try to find name + for name, f in self.name2jnp.items(): + if f == func: + return self._obtain_sympy_by_name(name) + raise ValueError(f"Func {func} doesn't not registered.") + + else: + raise ValueError(f"Func {func} need be a string or callable.") + + def _obtain_sympy_by_name(self, name: str): + if name not in self.name2sympy: + raise ValueError(f"Func {name} doesn't have a sympy representation.") + return self.name2sympy[name] diff --git a/src/tensorneat/common/graph.py b/src/tensorneat/common/graph.py index 533bd02..f783b8e 100644 --- a/src/tensorneat/common/graph.py +++ b/src/tensorneat/common/graph.py @@ -1,6 +1,5 @@ """ -Some graph algorithm implemented in jax. -Only used in feed-forward networks. +Some graph algorithm implemented in jax and python. """ import jax diff --git a/src/tensorneat/common/stateful_class.py b/src/tensorneat/common/stateful_class.py index 7646493..5a87f30 100644 --- a/src/tensorneat/common/stateful_class.py +++ b/src/tensorneat/common/stateful_class.py @@ -1,4 +1,3 @@ -import json from typing import Optional from . import State import pickle diff --git a/src/tensorneat/common/tools.py b/src/tensorneat/common/tools.py index ce27176..b3d3d0c 100644 --- a/src/tensorneat/common/tools.py +++ b/src/tensorneat/common/tools.py @@ -4,7 +4,9 @@ import numpy as np import jax from jax import numpy as jnp, Array, jit, vmap -I_INF = np.iinfo(jnp.int32).max # infinite int +# infinite int, use to represent the unavialable index in int32 array. +# as we can not use nan in int32 array +I_INF = np.iinfo(jnp.int32).max def attach_with_inf(arr, idx): @@ -100,6 +102,9 @@ def argmin_with_mask(arr, mask): def hash_array(arr: Array): + """ + Hash an array of uint32 to a single uint + """ arr = jax.lax.bitcast_convert_type(arr, jnp.uint32) def update(i, hash_val): diff --git a/src/tensorneat/problem/func_fit/func_fit.py b/src/tensorneat/problem/func_fit/func_fit.py index 70730e6..719c7ea 100644 --- a/src/tensorneat/problem/func_fit/func_fit.py +++ b/src/tensorneat/problem/func_fit/func_fit.py @@ -1,4 +1,5 @@ import jax +from jax import vmap, numpy as jnp import jax.numpy as jnp from ..base import BaseProblem @@ -19,7 +20,7 @@ class FuncFit(BaseProblem): def evaluate(self, state, randkey, act_func, params): - predict = jax.vmap(act_func, in_axes=(None, None, 0))( + predict = vmap(act_func, in_axes=(None, None, 0))( state, params, self.inputs ) @@ -41,7 +42,7 @@ class FuncFit(BaseProblem): return -loss def show(self, state, randkey, act_func, params, *args, **kwargs): - predict = jax.vmap(act_func, in_axes=(None, None, 0))( + predict = vmap(act_func, in_axes=(None, None, 0))( state, params, self.inputs ) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])