add License and pyproject.toml
This commit is contained in:
56
src/tensorneat/common/__init__.py
Normal file
56
src/tensorneat/common/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from tensorneat.common.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
|
||||
from .activation.act_jnp import Act, ACT_ALL, act_func
|
||||
from .aggregation.agg_sympy import *
|
||||
from .activation.act_sympy import *
|
||||
|
||||
from typing import Callable, Union
|
||||
|
||||
# Mapping from activation/aggregation names (the __name__ of the jax
# implementations in Act / Agg) to their sympy counterparts.
# Consumed by convert_to_sympy below.
name2sympy = {
    "sigmoid": SympySigmoid,
    "standard_sigmoid": SympyStandardSigmoid,
    "tanh": SympyTanh,
    "standard_tanh": SympyStandardTanh,
    "sin": SympySin,
    "relu": SympyRelu,
    "lelu": SympyLelu,
    "identity": SympyIdentity,
    "inv": SympyInv,
    "log": SympyLog,
    "exp": SympyExp,
    "abs": SympyAbs,
    "sum": SympySum,
    "product": SympyProduct,
    "max": SympyMax,
    "min": SympyMin,
    "maxabs": SympyMaxabs,
    "mean": SympyMean,
    "clip": SympyClip,
    "square": SympySquare,
}
|
||||
|
||||
|
||||
def convert_to_sympy(func: Union[str, Callable]):
    """Look up the sympy counterpart of an activation/aggregation function.

    ``func`` may be the name itself or a callable whose ``__name__`` is used
    as the lookup key into ``name2sympy``.

    Raises:
        ValueError: when no sympy counterpart is registered for the name.
    """
    name = func if isinstance(func, str) else func.__name__
    # guard clause instead of if/else nesting
    if name not in name2sympy:
        raise ValueError(
            f"Can not convert to sympy! Function {name} not found in name2sympy"
        )
    return name2sympy[name]
|
||||
|
||||
|
||||
# Tables mapping sympy class names to numerical implementations:
# *_NP evaluates with numpy (numerical_eval's default backend), *_JNP with
# jax.numpy. Classes without a numerical_eval are skipped.
# NOTE(review): presumably these are passed as ``modules=`` to
# sympy.lambdify -- confirm against the callers.
SYMPY_FUNCS_MODULE_NP = {}
SYMPY_FUNCS_MODULE_JNP = {}
for cls in name2sympy.values():
    if hasattr(cls, "numerical_eval"):
        SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
        # ``partial`` and ``jnp`` are apparently provided by the star
        # imports above (see .tools) -- confirm
        SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)
|
||||
0
src/tensorneat/common/activation/__init__.py
Normal file
0
src/tensorneat/common/activation/__init__.py
Normal file
110
src/tensorneat/common/activation/act_jnp.py
Normal file
110
src/tensorneat/common/activation/act_jnp.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
sigma_3 = 2.576  # output scale used by most activations: range (-sigma_3, sigma_3)


class Act:
    """Collection of jax activation functions.

    Most activations rescale and/or clip their input so the output stays in
    (-sigma_3, sigma_3) or a sub-range of it (noted per function).
    """

    @staticmethod
    def name2func(name):
        """Return the activation function whose name is ``name``."""
        return getattr(Act, name)

    @staticmethod
    def sigmoid(z):
        z = 5 * z / sigma_3
        z = 1 / (1 + jnp.exp(-z))
        return z * sigma_3  # (0, sigma_3)

    @staticmethod
    def standard_sigmoid(z):
        z = 5 * z / sigma_3
        z = 1 / (1 + jnp.exp(-z))
        return z  # (0, 1)

    @staticmethod
    def tanh(z):
        z = 5 * z / sigma_3
        return jnp.tanh(z) * sigma_3  # (-sigma_3, sigma_3)

    @staticmethod
    def standard_tanh(z):
        z = 5 * z / sigma_3
        return jnp.tanh(z)  # (-1, 1)

    @staticmethod
    def sin(z):
        z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2)
        return jnp.sin(z) * sigma_3  # (-sigma_3, sigma_3)

    @staticmethod
    def relu(z):
        z = jnp.clip(z, -sigma_3, sigma_3)
        return jnp.maximum(z, 0)  # (0, sigma_3)

    @staticmethod
    def lelu(z):
        # leaky relu with a small negative slope
        leaky = 0.005
        z = jnp.clip(z, -sigma_3, sigma_3)
        return jnp.where(z > 0, z, leaky * z)

    @staticmethod
    def identity(z):
        return z

    @staticmethod
    def inv(z):
        # clamp |z| away from 0 (sign preserved) to avoid division blow-up
        z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
        return 1 / z

    @staticmethod
    def log(z):
        z = jnp.maximum(z, 1e-7)  # avoid log of non-positive values
        return jnp.log(z)

    @staticmethod
    def exp(z):
        z = jnp.clip(z, -10, 10)  # avoid overflow
        return jnp.exp(z)

    @staticmethod
    def square(z):
        # fix: use jnp.power instead of jnp.pow -- identical semantics, but
        # jnp.pow only exists in recent jax releases (NumPy-2 alias)
        return jnp.power(z, 2)

    @staticmethod
    def abs(z):
        z = jnp.clip(z, -1, 1)
        return jnp.abs(z)
|
||||
|
||||
|
||||
# Default activation set; indices into this tuple are what act_func
# dispatches on, so the order matters.
# NOTE(review): Act.standard_sigmoid, Act.standard_tanh and Act.square are
# defined above but not listed here -- presumably excluded by design;
# confirm before relying on the indices.
ACT_ALL = (
    Act.sigmoid,
    Act.tanh,
    Act.sin,
    Act.relu,
    Act.lelu,
    Act.identity,
    Act.inv,
    Act.log,
    Act.exp,
    Act.abs,
)
|
||||
|
||||
|
||||
def act_func(idx, z, act_funcs):
    """Apply the activation selected by ``idx`` to ``z``.

    ``idx`` indexes into ``act_funcs``; the special value -1 means the
    identity activation.
    """
    # idx may arrive as a float stored in the genome; cast for lax.switch
    idx = jnp.asarray(idx, dtype=jnp.int32)

    def _identity():
        return z

    def _dispatch():
        return jax.lax.switch(idx, act_funcs, z)

    return jax.lax.cond(idx == -1, _identity, _dispatch)
|
||||
196
src/tensorneat/common/activation/act_sympy.py
Normal file
196
src/tensorneat/common/activation/act_sympy.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
|
||||
sigma_3 = 2.576
|
||||
|
||||
|
||||
class SympyClip(sp.Function):
    """Symbolic clip(val, min_val, max_val)."""

    @classmethod
    def eval(cls, val, min_val, max_val):
        # Fold to a Piecewise only when all arguments are concrete numbers;
        # returning None keeps the expression unevaluated (sympy convention).
        if val.is_Number and min_val.is_Number and max_val.is_Number:
            return sp.Piecewise(
                (min_val, val < min_val), (max_val, val > max_val), (val, True)
            )
        return None

    @staticmethod
    def numerical_eval(val, min_val, max_val, backend=np):
        # numeric counterpart; backend is any numpy-compatible module
        return backend.clip(val, min_val, max_val)

    def _sympystr(self, printer):
        # plain-text printing
        return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"

    def _latex(self, printer):
        # LaTeX printing
        return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
|
||||
|
||||
|
||||
class SympySigmoid_(sp.Function):
    """Raw logistic sigmoid 1 / (1 + e^-z), without TensorNEAT rescaling."""

    @classmethod
    def eval(cls, z):
        return 1 / (1 + sp.exp(-z))

    @staticmethod
    def numerical_eval(z, backend=np):
        """Numeric counterpart with a numpy-compatible backend."""
        return 1 / (1 + backend.exp(-z))

    def _sympystr(self, printer):
        return f"sigmoid({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
    """Sigmoid scaled as in Act.sigmoid: range (0, sigma_3)."""

    @classmethod
    def eval(cls, z):
        scaled = 5 * z / sigma_3
        return SympySigmoid_(scaled) * sigma_3
|
||||
|
||||
|
||||
class SympyStandardSigmoid(sp.Function):
    """Sigmoid with input rescaling only: range (0, 1)."""

    @classmethod
    def eval(cls, z):
        scaled = 5 * z / sigma_3
        return SympySigmoid_(scaled)
|
||||
|
||||
|
||||
class SympyTanh(sp.Function):
    """tanh scaled as in Act.tanh: range (-sigma_3, sigma_3)."""

    @classmethod
    def eval(cls, z):
        scaled = 5 * z / sigma_3
        return sp.tanh(scaled) * sigma_3
|
||||
|
||||
|
||||
class SympyStandardTanh(sp.Function):
    """tanh with input rescaling only: range (-1, 1)."""

    @classmethod
    def eval(cls, z):
        scaled = 5 * z / sigma_3
        return sp.tanh(scaled)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
    """Symbolic counterpart of Act.sin (clip, sine, scale by sigma_3)."""

    @classmethod
    def eval(cls, z):
        # Fold only for concrete numbers; None keeps it unevaluated.
        if z.is_Number:
            z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2)
            return sp.sin(z) * sigma_3  # (-sigma_3, sigma_3)
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        # mirrors Act.sin in act_jnp.py
        z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2)
        return backend.sin(z) * sigma_3  # (-sigma_3, sigma_3)
|
||||
|
||||
|
||||
class SympyRelu(sp.Function):
    """Symbolic counterpart of Act.relu (clip to [-sigma_3, sigma_3], then max with 0)."""

    @classmethod
    def eval(cls, z):
        # Fold only for concrete numbers; returning None keeps the
        # expression unevaluated (sympy auto-eval convention).
        if z.is_Number:
            z = SympyClip(z, -sigma_3, sigma_3)
            return sp.Max(z, 0)  # (0, sigma_3)
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        # mirrors Act.relu in act_jnp.py
        z = backend.clip(z, -sigma_3, sigma_3)
        return backend.maximum(z, 0)  # (0, sigma_3)

    def _sympystr(self, printer):
        return f"relu({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyLelu(sp.Function):
    """Symbolic leaky relu (slope 0.005 on the negative side).

    NOTE(review): unlike Act.lelu in act_jnp.py there is no clipping to
    [-sigma_3, sigma_3] here -- confirm this asymmetry is intended.
    """

    @classmethod
    def eval(cls, z):
        if z.is_Number:
            leaky = 0.005
            return sp.Piecewise((z, z > 0), (leaky * z, True))
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        leaky = 0.005
        # maximum(z, leaky*z) equals the piecewise above because leaky < 1:
        # z wins for z > 0, leaky*z wins for z < 0
        return backend.maximum(z, leaky * z)

    def _sympystr(self, printer):
        return f"lelu({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyIdentity(sp.Function):
    """Identity: always evaluates straight to its argument."""

    @classmethod
    def eval(cls, z):
        return z
|
||||
|
||||
|
||||
class SympyInv(sp.Function):
    """Symbolic counterpart of Act.inv: 1/z with |z| clamped away from zero."""

    @classmethod
    def eval(cls, z):
        if z.is_Number:
            # clamp |z| >= 1e-7, preserving the sign, before inverting
            z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
            return 1 / z
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        # Bug fix: the previous ``backend.maximum(z, 1e-7)`` mapped every
        # negative input to 1/1e-7 = 1e7. Mirror Act.inv (act_jnp.py) and
        # the eval above: clamp the magnitude but keep the sign.
        z = backend.where(z > 0, backend.maximum(z, 1e-7), backend.minimum(z, -1e-7))
        return 1 / z

    def _sympystr(self, printer):
        return f"1 / {self.args[0]}"

    def _latex(self, printer):
        return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
|
||||
|
||||
|
||||
class SympyLog(sp.Function):
    """Symbolic counterpart of Act.log (input clamped to >= 1e-7)."""

    @classmethod
    def eval(cls, z):
        if z.is_Number:
            z = sp.Max(z, 1e-7)  # avoid log of non-positive values
            return sp.log(z)
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        # mirrors Act.log in act_jnp.py
        z = backend.maximum(z, 1e-7)
        return backend.log(z)

    def _sympystr(self, printer):
        return f"log({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyExp(sp.Function):
    """Symbolic counterpart of Act.exp (input clipped to [-10, 10])."""

    @classmethod
    def eval(cls, z):
        if z.is_Number:
            z = SympyClip(z, -10, 10)  # avoid overflow, as in Act.exp
            return sp.exp(z)
        return None

    # NOTE(review): no numerical_eval here, so this class is skipped when the
    # SYMPY_FUNCS_MODULE_* tables are built -- confirm that is intended.

    def _sympystr(self, printer):
        return f"exp({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympySquare(sp.Function):
    """Square: evaluates to z**2."""

    @classmethod
    def eval(cls, z):
        return sp.Pow(z, 2)
|
||||
|
||||
|
||||
class SympyAbs(sp.Function):
    """Absolute value: evaluates to sympy's Abs."""

    @classmethod
    def eval(cls, z):
        return sp.Abs(z)
|
||||
0
src/tensorneat/common/aggregation/__init__.py
Normal file
0
src/tensorneat/common/aggregation/__init__.py
Normal file
66
src/tensorneat/common/aggregation/agg_jnp.py
Normal file
66
src/tensorneat/common/aggregation/agg_jnp.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Agg:
    """Collection of jax aggregation functions.

    Each function reduces a 1-D array of incoming values; NaN entries mark
    missing inputs and are ignored.
    """

    @staticmethod
    def name2func(name):
        """Return the aggregation function whose name is ``name``."""
        return getattr(Agg, name)

    @staticmethod
    def sum(z):
        valid = ~jnp.isnan(z)
        return jnp.sum(z, axis=0, where=valid, initial=0)

    @staticmethod
    def product(z):
        valid = ~jnp.isnan(z)
        return jnp.prod(z, axis=0, where=valid, initial=1)

    @staticmethod
    def max(z):
        valid = ~jnp.isnan(z)
        return jnp.max(z, axis=0, where=valid, initial=-jnp.inf)

    @staticmethod
    def min(z):
        valid = ~jnp.isnan(z)
        return jnp.min(z, axis=0, where=valid, initial=jnp.inf)

    @staticmethod
    def maxabs(z):
        # value with the largest magnitude, sign preserved; NaNs become 0
        cleaned = jnp.where(jnp.isnan(z), 0, z)
        winner = jnp.argmax(jnp.abs(cleaned))
        return cleaned[winner]

    @staticmethod
    def median(z):
        # count the valid entries, then sort (NaNs sort to the end) and
        # average the two middle valid positions
        valid_cnt = jnp.sum(~jnp.isnan(z), axis=0)
        ordered = jnp.sort(z)
        lo, hi = (valid_cnt - 1) // 2, valid_cnt // 2
        return (ordered[lo] + ordered[hi]) / 2

    @staticmethod
    def mean(z):
        cleaned = jnp.where(jnp.isnan(z), 0, z)
        total = jnp.sum(cleaned, axis=0)
        count = jnp.sum(~jnp.isnan(z), axis=0)
        return total / count
|
||||
|
||||
|
||||
# Default aggregation set; indices into this tuple are what agg_func
# dispatches on, so the order matters.
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
|
||||
|
||||
|
||||
def agg_func(idx, z, agg_funcs):
    """Aggregate the inputs ``z`` of a node with the function at ``idx``.

    Returns NaN when every input is NaN (i.e. the node received nothing).
    """
    # idx may arrive as a float stored in the genome; cast for lax.switch
    idx = jnp.asarray(idx, dtype=jnp.int32)

    def _no_inputs():
        return jnp.nan

    def _dispatch():
        return jax.lax.switch(idx, agg_funcs, z)

    return jax.lax.cond(jnp.all(jnp.isnan(z)), _no_inputs, _dispatch)
|
||||
65
src/tensorneat/common/aggregation/agg_sympy.py
Normal file
65
src/tensorneat/common/aggregation/agg_sympy.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
|
||||
|
||||
class SympySum(sp.Function):
    """Symbolic sum of a sequence of expressions."""

    @classmethod
    def eval(cls, z):
        return sp.Add(*z)

    @classmethod
    def numerical_eval(cls, z, backend=np):
        """Numeric counterpart with a numpy-compatible backend."""
        return backend.sum(z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
    """Symbolic product of a sequence of expressions."""

    @classmethod
    def eval(cls, z):
        return sp.Mul(*z)
|
||||
|
||||
|
||||
class SympyMax(sp.Function):
    """Symbolic maximum of a sequence of expressions."""

    @classmethod
    def eval(cls, z):
        return sp.Max(*z)
|
||||
|
||||
|
||||
class SympyMin(sp.Function):
    """Symbolic minimum of a sequence of expressions."""

    @classmethod
    def eval(cls, z):
        return sp.Min(*z)
|
||||
|
||||
|
||||
class SympyMaxabs(sp.Function):
    """Element with the largest absolute value, sign preserved.

    Mirrors Agg.maxabs in agg_jnp.py.
    """

    @classmethod
    def eval(cls, z):
        # Bug fix: sympy's Max does not accept a ``key`` argument, so the old
        # ``sp.Max(*z, key=sp.Abs)`` never selected by magnitude. Evaluate
        # only when every element is a concrete number, keeping the winner's
        # sign; otherwise stay unevaluated (return None).
        zs = [sp.sympify(e) for e in z]
        if all(e.is_number for e in zs):
            return max(zs, key=lambda e: sp.Abs(e))
        return None
|
||||
|
||||
|
||||
class SympyMean(sp.Function):
    """Symbolic arithmetic mean of a sequence of expressions."""

    @classmethod
    def eval(cls, z):
        return sp.Add(*z) / len(z)
|
||||
|
||||
|
||||
class SympyMedian(sp.Function):
    """Median of a sequence; evaluates only for fully numeric arguments."""

    @classmethod
    def eval(cls, args):
        # stay unevaluated while any argument is symbolic
        if not all(arg.is_number for arg in args):
            return None
        ordered = sorted(args)
        count = len(ordered)
        mid = count // 2
        if count % 2 == 1:
            return ordered[mid]
        return (ordered[mid - 1] + ordered[mid]) / 2

    def _sympystr(self, printer):
        return f"median({', '.join(map(str, self.args))})"

    def _latex(self, printer):
        return (
            r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
        )
|
||||
123
src/tensorneat/common/graph.py
Normal file
123
src/tensorneat/common/graph.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Some graph algorithm implemented in jax.
|
||||
Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
from typing import Tuple, Set, List, Union
|
||||
|
||||
from .tools import fetch_first, I_INF
|
||||
|
||||
|
||||
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
    """
    a jit-able version of topological_sort!
    conns: Array[N, N]

    Kahn's algorithm expressed with jax.lax.while_loop so it can be jitted.
    A NaN in nodes[:, 0] marks an empty node slot. Returns an Array[N] of
    node indices in topological order; unused tail entries stay I_INF
    (slots whose in-degree never reaches 0, e.g. empty slots, also stay out).
    """

    # in-degree per node; NaN for empty slots so they never compare == 0
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
    res = jnp.full(in_degree.shape, I_INF)

    def cond_fun(carry):
        res_, idx_, in_degree_ = carry
        # keep looping while some node still has in-degree exactly 0
        i = fetch_first(in_degree_ == 0.0)
        return i != I_INF

    def body_func(carry):
        res_, idx_, in_degree_ = carry
        i = fetch_first(in_degree_ == 0.0)

        # add to res and flag it is already in it
        res_ = res_.at[idx_].set(i)
        in_degree_ = in_degree_.at[i].set(-1)  # -1: processed, can never be 0 again

        # decrease in_degree of all its children
        children = conns[i, :]
        in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
        return res_, idx_ + 1, in_degree_

    res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
    return res
|
||||
|
||||
|
||||
def topological_sort_python(
    nodes: Union[Set[int], List[int]],
    conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
) -> Tuple[List[int], List[List[int]]]:
    """Kahn's algorithm on plain python collections.

    Returns:
        (topological order, topological layers) -- each layer is sorted
        ascending so the overall order is deterministic.

    Raises:
        ValueError: when the graph contains a cycle.
    """
    # work on copies; the inputs are never mutated
    remaining_nodes = nodes.copy()
    remaining_conns = conns.copy()

    # compute the in-degree of every node
    in_degree = {n: 0 for n in remaining_nodes}
    for _, dst in remaining_conns:
        in_degree[dst] += 1

    topo_order: List[int] = []
    topo_layer: List[List[int]] = []

    frontier = [n for n in remaining_nodes if in_degree[n] == 0]
    while frontier:
        for n in frontier:
            remaining_nodes.remove(n)

        # sort so the order inside a layer goes from small to large
        frontier = sorted(frontier)
        topo_layer.append(frontier.copy())

        for n in frontier:
            topo_order.append(n)
            # drop n's outgoing edges and lower the targets' in-degrees
            for conn in list(remaining_conns):
                if conn[0] == n:
                    in_degree[conn[1]] -= 1
                    remaining_conns.remove(conn)

        frontier = [n for n in remaining_nodes if in_degree[n] == 0]

    # anything left over means the graph had a cycle
    if remaining_conns or remaining_nodes:
        raise ValueError("Graph has at least one cycle, topological sort not possible")

    return topo_order, topo_layer
|
||||
|
||||
|
||||
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
    """
    Check whether a new connection (from_idx -> to_idx) will cause a cycle.

    Runs a breadth-first reachability sweep starting from to_idx with the
    candidate edge inserted; a cycle exists iff from_idx becomes reachable.
    """
    # tentatively insert the candidate edge
    conns = conns.at[from_idx, to_idx].set(True)

    seen = jnp.full(nodes.shape[0], False)
    frontier = seen.at[to_idx].set(True)

    def keep_going(carry):
        prev, cur = carry
        no_growth = jnp.all(prev == cur)  # reachable set stopped growing
        reached_src = cur[from_idx]       # found a path back to from_idx
        return jnp.logical_not(no_growth | reached_src)

    def expand(carry):
        _, cur = carry
        # one hop along every outgoing edge of the current reachable set
        nxt = jnp.dot(cur, conns)
        nxt = jnp.logical_or(cur, nxt)
        return cur, nxt

    _, seen = jax.lax.while_loop(keep_going, expand, (seen, frontier))
    return seen[from_idx]
|
||||
49
src/tensorneat/common/state.py
Normal file
49
src/tensorneat/common/state.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
@register_pytree_node_class
class State:
    """Immutable, pytree-compatible container of named values.

    All values live in one ``state_dict``; attribute access is forwarded to
    it. ``register``/``update`` never mutate -- they return a new State.
    """

    def __init__(self, **kwargs):
        # write through __dict__ directly because __setattr__ is disabled
        self.__dict__["state_dict"] = kwargs

    def registered_keys(self):
        """Keys currently stored in this state."""
        return self.state_dict.keys()

    def register(self, **kwargs):
        """Return a new State extended with previously-absent keys."""
        for key in kwargs:
            if key in self.registered_keys():
                raise ValueError(f"Key {key} already exists in state")
        return State(**{**self.state_dict, **kwargs})

    def update(self, **kwargs):
        """Return a new State with fresh values for already-present keys."""
        for key in kwargs:
            if key not in self.registered_keys():
                raise ValueError(f"Key {key} does not exist in state")
        return State(**{**self.state_dict, **kwargs})

    def __getattr__(self, name):
        # NOTE(review): a missing key raises KeyError rather than
        # AttributeError, so hasattr() on a State propagates KeyError --
        # confirm callers rely on this before changing the exception type.
        return self.state_dict[name]

    def __setattr__(self, name, value):
        raise AttributeError("State is immutable")

    def __repr__(self):
        return f"State ({self.state_dict})"

    def __getstate__(self):
        # pickle support: the dict is the whole state
        return self.state_dict.copy()

    def __setstate__(self, state):
        self.__dict__["state_dict"] = state

    def __contains__(self, item):
        return item in self.state_dict

    def tree_flatten(self):
        # values are the pytree children; keys ride along as aux data
        return list(self.state_dict.values()), list(self.state_dict.keys())

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(**dict(zip(aux_data, children)))
|
||||
69
src/tensorneat/common/stateful_class.py
Normal file
69
src/tensorneat/common/stateful_class.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from . import State
|
||||
import pickle
|
||||
import datetime
|
||||
import warnings
|
||||
|
||||
|
||||
class StatefulBaseClass:
    """Base class for objects whose runtime data lives in a State.

    Provides pickling helpers that skip unpicklable attributes, and
    save/load that can optionally bundle a State alongside the object.
    """

    def setup(self, state=None):
        """Initialize and return the state; subclasses extend this.

        Fix: the default used to be ``state=State()`` -- a single instance
        created once at class-definition time and shared by every call
        (the mutable-default pitfall). Use a None sentinel and build a
        fresh State per call instead; observable behavior is unchanged.
        """
        return State() if state is None else state

    def save(self, state: Optional[State] = None, path: Optional[str] = None):
        """Pickle self to ``path`` (auto-named from class + timestamp when
        omitted), optionally embedding ``state`` for load(..., with_state=True)."""
        if path is None:
            time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            path = f"./{self.__class__.__name__} {time}.pkl"
        if state is not None:
            self.__dict__["aux_for_state"] = state
        with open(path, "wb") as f:
            pickle.dump(self, f)

    def __getstate__(self):
        # only pickle the picklable attributes; probe each value with a
        # trial pickle.dumps and drop the ones that fail
        state = self.__dict__.copy()
        non_picklable_keys = []
        for key, value in state.items():
            try:
                pickle.dumps(value)
            except Exception:
                non_picklable_keys.append(key)

        for key in non_picklable_keys:
            state.pop(key)

        return state

    def show_config(self):
        """Return a nested dict describing this object's configuration,
        recursing into StatefulBaseClass-valued attributes."""
        config = {}
        for key, value in self.__dict__.items():
            if isinstance(value, StatefulBaseClass):
                config[str(key)] = value.show_config()
            else:
                config[str(key)] = str(value)
        return config

    @classmethod
    def load(cls, path: str, with_state: bool = False, warning: bool = True):
        """Unpickle an object written by ``save``.

        With ``with_state=True`` returns (obj, state) -- an empty State (and
        an optional warning) when none was embedded. Otherwise returns obj
        alone, discarding any embedded state (with an optional warning).
        """
        with open(path, "rb") as f:
            obj = pickle.load(f)
        if with_state:
            if "aux_for_state" not in obj.__dict__:
                if warning:
                    warnings.warn(
                        "This object does not have state to load, return empty state",
                        category=UserWarning,
                    )
                return obj, State()
            state = obj.__dict__["aux_for_state"]
            del obj.__dict__["aux_for_state"]
            return obj, state
        else:
            if "aux_for_state" in obj.__dict__:
                if warning:
                    warnings.warn(
                        "This object has state to load, ignore it",
                        category=UserWarning,
                    )
                del obj.__dict__["aux_for_state"]
            return obj
|
||||
110
src/tensorneat/common/tools.py
Normal file
110
src/tensorneat/common/tools.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INF = np.iinfo(jnp.int32).max  # sentinel index: "infinite int" / no entry


def attach_with_inf(arr, idx):
    """Gather arr[idx], but yield NaN wherever idx equals the I_INF sentinel.

    The sentinel mask is padded with trailing singleton axes so it
    broadcasts against whole gathered sub-arrays when arr has extra
    trailing dimensions.
    """
    target_dim = arr.ndim + idx.ndim - 1
    sentinel_mask = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
    return jnp.where(sentinel_mask == I_INF, jnp.nan, arr[idx])
|
||||
|
||||
|
||||
@jit
def fetch_first(mask, default=I_INF) -> Array:
    """Return the index of the first True element of ``mask``.

    :param mask: array of bool
    :param default: value returned when no element is True
    :return: index of the first True element, or ``default`` when none is
    """
    # argmax returns the position of the first maximum, i.e. the first True
    first = jnp.argmax(mask)
    return jnp.where(mask[first], first, default)
|
||||
|
||||
|
||||
@jit
def fetch_random(randkey, mask, default=I_INF) -> Array:
    """Like fetch_first, but return a uniformly random True index."""
    total = jnp.sum(mask)
    # pick the target-th True element via the running count of Trues
    running = jnp.cumsum(mask)
    target = jax.random.randint(randkey, shape=(), minval=1, maxval=total + 1)
    hit = jnp.where(total == 0, False, running >= target)
    return fetch_first(hit, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["reverse"])
def rank_elements(array, reverse=False):
    """Rank of every element of ``array`` (rank 0 = first).

    By default the largest element gets rank 0; with ``reverse=True`` the
    smallest does.
    """
    # double argsort turns sort positions into per-element ranks
    keys = array if reverse else -array
    return jnp.argsort(jnp.argsort(keys))
|
||||
|
||||
|
||||
@jit
def mutate_float(
    randkey, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate
):
    """
    mutate a float value
    uniformly pick r from [0, 1]
    r in [0, mutate_rate) -> add noise
    r in [mutate_rate, mutate_rate + replace_rate) -> create a new value to replace the original value
    otherwise -> keep the original value
    """
    noise_key, replace_key, roll_key = jax.random.split(randkey, num=3)
    noise = jax.random.normal(noise_key, ()) * mutate_power
    replacement = jax.random.normal(replace_key, ()) * init_std + init_mean
    r = jax.random.uniform(roll_key, ())

    # NOTE(review): both comparisons around mutate_rate are strict, so
    # r == mutate_rate keeps the original value, slightly differing from the
    # docstring; measure-zero for floats.
    replaced = jnp.where(
        (mutate_rate < r) & (r < mutate_rate + replace_rate), replacement, val
    )
    return jnp.where(r < mutate_rate, val + noise, replaced)
|
||||
|
||||
|
||||
@jit
def mutate_int(randkey, val, options, replace_rate):
    """
    mutate an int value
    uniformly pick r from [0, 1]
    r in [0, replace_rate) -> create a new value to replace the original value
    otherwise -> keep the original value
    """
    roll_key, choice_key = jax.random.split(randkey, num=2)
    r = jax.random.uniform(roll_key, ())
    # replace with a uniform choice from options, else keep the original
    return jnp.where(r < replace_rate, jax.random.choice(choice_key, options), val)
|
||||
|
||||
|
||||
def argmin_with_mask(arr, mask):
    """Index of the smallest element among positions where ``mask`` is True."""
    # masked-out entries become +inf so they can never win the argmin
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))
|
||||
|
||||
|
||||
def hash_array(arr: Array):
    """Deterministic uint32 hash of an array's raw bits.

    Bit-casts the elements to uint32 and folds them left-to-right with the
    golden-ratio mixing constant 0x9E3779B9 (as in boost's hash_combine).
    """
    bits = jax.lax.bitcast_convert_type(arr, jnp.uint32)

    def _mix(i, h):
        return h ^ (bits[i] + jnp.uint32(0x9E3779B9) + (h << 6) + (h >> 2))

    return jax.lax.fori_loop(0, bits.size, _mix, jnp.uint32(0))
|
||||
Reference in New Issue
Block a user