change str in config (act, agg) from str to callable

This commit is contained in:
wls2002
2023-08-05 03:03:02 +08:00
parent 0e44b13291
commit af54db3b12
10 changed files with 55 additions and 101 deletions

View File

@@ -1,35 +1,4 @@
from .activation import Activation, act
from .aggregation import Aggregation, agg
from .activation import Act, act
from .aggregation import Agg, agg
from .tools import *
from .graph import *
Activation.name2func = {
'sigmoid': Activation.sigmoid_act,
'tanh': Activation.tanh_act,
'sin': Activation.sin_act,
'gauss': Activation.gauss_act,
'relu': Activation.relu_act,
'elu': Activation.elu_act,
'lelu': Activation.lelu_act,
'selu': Activation.selu_act,
'softplus': Activation.softplus_act,
'identity': Activation.identity_act,
'clamped': Activation.clamped_act,
'inv': Activation.inv_act,
'log': Activation.log_act,
'exp': Activation.exp_act,
'abs': Activation.abs_act,
'hat': Activation.hat_act,
'square': Activation.square_act,
'cube': Activation.cube_act,
}
Aggregation.name2func = {
'sum': Aggregation.sum_agg,
'product': Aggregation.product_agg,
'max': Aggregation.max_agg,
'min': Aggregation.min_agg,
'maxabs': Aggregation.maxabs_agg,
'median': Aggregation.median_agg,
'mean': Aggregation.mean_agg,
}
from .graph import *

View File

@@ -2,90 +2,89 @@ import jax
import jax.numpy as jnp
class Activation:
name2func = {}
class Act:
@staticmethod
def sigmoid_act(z):
def sigmoid(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@staticmethod
def tanh_act(z):
def tanh(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@staticmethod
def sin_act(z):
def sin(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@staticmethod
def gauss_act(z):
def gauss(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@staticmethod
def relu_act(z):
def relu(z):
return jnp.maximum(z, 0)
@staticmethod
def elu_act(z):
def elu(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@staticmethod
def lelu_act(z):
def lelu(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def selu_act(z):
def selu(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@staticmethod
def softplus_act(z):
def softplus(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@staticmethod
def identity_act(z):
def identity(z):
return z
@staticmethod
def clamped_act(z):
def clamped(z):
return jnp.clip(z, -1, 1)
@staticmethod
def inv_act(z):
def inv(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@staticmethod
def log_act(z):
def log(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@staticmethod
def exp_act(z):
def exp(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@staticmethod
def abs_act(z):
def abs(z):
return jnp.abs(z)
@staticmethod
def hat_act(z):
def hat(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@staticmethod
def square_act(z):
def square(z):
return z ** 2
@staticmethod
def cube_act(z):
def cube(z):
return z ** 3

View File

@@ -2,38 +2,37 @@ import jax
import jax.numpy as jnp
class Aggregation:
name2func = {}
class Agg:
@staticmethod
def sum_agg(z):
def sum(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@staticmethod
def product_agg(z):
def product(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@staticmethod
def max_agg(z):
def max(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@staticmethod
def min_agg(z):
def min(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@staticmethod
def maxabs_agg(z):
def maxabs(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@staticmethod
def median_agg(z):
def median(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
@@ -44,7 +43,7 @@ class Aggregation:
return median
@staticmethod
def mean_agg(z):
def mean(z):
aux = jnp.where(jnp.isnan(z), 0, z)
valid_values_sum = jnp.sum(aux, axis=0)
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)