change a lot a lot a lot!!!!!!!

This commit is contained in:
wls2002
2023-07-24 02:16:02 +08:00
parent 48f90c7eef
commit ac295c1921
49 changed files with 1138 additions and 1460 deletions

4
utils/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .activation import Activation
from .aggregation import Aggregation
from .tools import *
from .graph import *

111
utils/activation.py Normal file
View File

@@ -0,0 +1,111 @@
import jax.numpy as jnp
class Activation:
name2func = {}
@staticmethod
def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@staticmethod
def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@staticmethod
def sin_act(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@staticmethod
def gauss_act(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@staticmethod
def relu_act(z):
return jnp.maximum(z, 0)
@staticmethod
def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@staticmethod
def lelu_act(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@staticmethod
def softplus_act(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@staticmethod
def identity_act(z):
return z
@staticmethod
def clamped_act(z):
return jnp.clip(z, -1, 1)
@staticmethod
def inv_act(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@staticmethod
def log_act(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@staticmethod
def exp_act(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@staticmethod
def abs_act(z):
return jnp.abs(z)
@staticmethod
def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@staticmethod
def square_act(z):
return z ** 2
@staticmethod
def cube_act(z):
return z ** 3
Activation.name2func = {
'sigmoid': Activation.sigmoid_act,
'tanh': Activation.tanh_act,
'sin': Activation.sin_act,
'gauss': Activation.gauss_act,
'relu': Activation.relu_act,
'elu': Activation.elu_act,
'lelu': Activation.lelu_act,
'selu': Activation.selu_act,
'softplus': Activation.softplus_act,
'identity': Activation.identity_act,
'clamped': Activation.clamped_act,
'inv': Activation.inv_act,
'log': Activation.log_act,
'exp': Activation.exp_act,
'abs': Activation.abs_act,
'hat': Activation.hat_act,
'square': Activation.square_act,
'cube': Activation.cube_act,
}

63
utils/aggregation.py Normal file
View File

@@ -0,0 +1,63 @@
import jax.numpy as jnp
class Aggregation:
name2func = {}
@staticmethod
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@staticmethod
def product_agg(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@staticmethod
def max_agg(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@staticmethod
def min_agg(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@staticmethod
def maxabs_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@staticmethod
def median_agg(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
@staticmethod
def mean_agg(z):
aux = jnp.where(jnp.isnan(z), 0, z)
valid_values_sum = jnp.sum(aux, axis=0)
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
Aggregation.name2func = {
'sum': Aggregation.sum_agg,
'product': Aggregation.product_agg,
'max': Aggregation.max_agg,
'min': Aggregation.min_agg,
'maxabs': Aggregation.maxabs_agg,
'median': Aggregation.median_agg,
'mean': Aggregation.mean_agg,
}

68
utils/graph.py Normal file
View File

@@ -0,0 +1,68 @@
"""
Some graph algorithm implemented in jax.
Only used in feed-forward networks.
"""
import jax
from jax import jit, Array, numpy as jnp
from .tools import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
"""
a jit-able version of topological_sort!
conns: Array[N, N]
"""
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
res = jnp.full(in_degree.shape, I_INT)
def cond_fun(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
return i != I_INT
def body_func(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
# add to res and flag it is already in it
res_ = res_.at[idx_].set(i)
in_degree_ = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = conns[i, :]
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
return res_, idx_ + 1, in_degree_
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
return res
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
"""
conns = conns.at[from_idx, to_idx].set(True)
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
def cond_func(carry):
visited_, new_visited_ = carry
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
end_cond2 = new_visited_[from_idx] # the starting node has been visited
return jnp.logical_not(end_cond1 | end_cond2)
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, conns)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]

69
utils/tools.py Normal file
View File

@@ -0,0 +1,69 @@
from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
EMPTY_CON = np.full((1, 4), jnp.nan)
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (CL-2, N, N)
:return:
"""
N = nodes.shape[0]
CL = conns.shape[1]
node_keys = nodes[:, 0]
i_keys, o_keys = conns[:, 0], conns[:, 1]
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((CL - 2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array
# put all attributes include enable in res
res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
return res
def key_to_indices(key, keys):
return fetch_first(key == keys)
@jit
def fetch_first(mask, default=I_INT) -> Array:
"""
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
rank the element in the array.
if reverse is True, the rank is from small to large. default large to small
"""
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))