fix bugs
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
@@ -15,7 +13,9 @@ from typing import Union
|
||||
|
||||
name2sympy = {
|
||||
"sigmoid": SympySigmoid,
|
||||
"standard_sigmoid": SympyStandardSigmoid,
|
||||
"tanh": SympyTanh,
|
||||
"standard_tanh": SympyStandardTanh,
|
||||
"sin": SympySin,
|
||||
"relu": SympyRelu,
|
||||
"lelu": SympyLelu,
|
||||
|
||||
@@ -12,19 +12,26 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||
z = 5 * z / sigma_3
|
||||
z = 1 / (1 + jnp.exp(-z))
|
||||
|
||||
return z * sigma_3 # (0, sigma_3)
|
||||
|
||||
@staticmethod
|
||||
def standard_sigmoid(z):
|
||||
z = 5 * z / sigma_3
|
||||
z = 1 / (1 + jnp.exp(-z))
|
||||
|
||||
return z # (0, 1)
|
||||
|
||||
@staticmethod
|
||||
def tanh(z):
|
||||
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||
z = 5 * z / sigma_3
|
||||
return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||
|
||||
@staticmethod
|
||||
def standard_tanh(z):
|
||||
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||
z =5 * z / sigma_3
|
||||
return jnp.tanh(z) # (-1, 1)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -25,21 +25,16 @@ class SympyClip(sp.Function):
|
||||
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
class SympySigmoid_(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||
z = 1 / (1 + sp.exp(-z))
|
||||
return z * sigma_3
|
||||
return None
|
||||
z = 1 / (1 + sp.exp(-z))
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
z = 1 / (1 + backend.exp(-z))
|
||||
|
||||
return z * sigma_3 # (0, sigma_3)
|
||||
return z
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"sigmoid({self.args[0]})"
|
||||
@@ -48,32 +43,47 @@ class SympySigmoid(sp.Function):
|
||||
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return SympySigmoid_(5 * z / sigma_3) * sigma_3
|
||||
|
||||
|
||||
class SympyStandardSigmoid(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return SympySigmoid_(5 * z / sigma_3)
|
||||
|
||||
# @staticmethod
|
||||
# def numerical_eval(z, backend=np):
|
||||
# z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
# z = 1 / (1 + backend.exp(-z))
|
||||
#
|
||||
# return z # (0, 1)
|
||||
|
||||
|
||||
class SympyTanh(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||
return sp.tanh(z) * sigma_3
|
||||
return None
|
||||
z = 5 * z / sigma_3
|
||||
return sp.tanh(z) * sigma_3
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||
# @staticmethod
|
||||
# def numerical_eval(z, backend=np):
|
||||
# z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
# return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||
|
||||
|
||||
class SympyStandardTanh(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||
return sp.tanh(z)
|
||||
return None
|
||||
z = 5 * z / sigma_3
|
||||
return sp.tanh(z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
return backend.tanh(z) # (-1, 1)
|
||||
# @staticmethod
|
||||
# def numerical_eval(z, backend=np):
|
||||
# z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
# return backend.tanh(z) # (-1, 1)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
|
||||
|
||||
@@ -9,19 +9,19 @@ class Agg:
|
||||
|
||||
@staticmethod
|
||||
def sum(z):
|
||||
return jnp.sum(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0)
|
||||
|
||||
@staticmethod
|
||||
def product(z):
|
||||
return jnp.prod(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1)
|
||||
|
||||
@staticmethod
|
||||
def max(z):
|
||||
return jnp.max(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf)
|
||||
|
||||
@staticmethod
|
||||
def min(z):
|
||||
return jnp.min(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf)
|
||||
|
||||
@staticmethod
|
||||
def maxabs(z):
|
||||
|
||||
@@ -36,6 +36,9 @@ class State:
|
||||
def __setstate__(self, state):
|
||||
self.__dict__["state_dict"] = state
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.state_dict
|
||||
|
||||
def tree_flatten(self):
|
||||
children = list(self.state_dict.values())
|
||||
aux_data = list(self.state_dict.keys())
|
||||
|
||||
@@ -19,6 +19,21 @@ class StatefulBaseClass:
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
def __getstate__(self):
|
||||
# only pickle the picklable attributes
|
||||
state = self.__dict__.copy()
|
||||
non_picklable_keys = []
|
||||
for key, value in state.items():
|
||||
try:
|
||||
pickle.dumps(value)
|
||||
except Exception:
|
||||
non_picklable_keys.append(key)
|
||||
|
||||
for key in non_picklable_keys:
|
||||
state.pop(key)
|
||||
|
||||
return state
|
||||
|
||||
def show_config(self):
|
||||
config = {}
|
||||
for key, value in self.__dict__.items():
|
||||
|
||||
@@ -36,6 +36,7 @@ def unflatten_conns(nodes, conns):
|
||||
return unflatten
|
||||
|
||||
|
||||
# TODO: strange implementation
|
||||
def attach_with_inf(arr, idx):
|
||||
expand_size = arr.ndim - idx.ndim
|
||||
expand_idx = jnp.expand_dims(
|
||||
@@ -199,3 +200,14 @@ def delete_conn_by_pos(conns, pos):
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def hash_array(arr: Array):
|
||||
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
|
||||
|
||||
def update(i, hash_val):
|
||||
return hash_val ^ (
|
||||
arr[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2)
|
||||
)
|
||||
|
||||
return jax.lax.fori_loop(0, arr.size, update, jnp.uint32(0))
|
||||
|
||||
Reference in New Issue
Block a user