finish all refactoring
This commit is contained in:
@@ -6,48 +6,26 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
z = jnp.clip(5 * z, -10, 10)
|
||||
return 1 / (1 + jnp.exp(-z))
|
||||
|
||||
@staticmethod
|
||||
def tanh(z):
|
||||
z = jnp.clip(z * 2.5, -60, 60)
|
||||
return jnp.tanh(z)
|
||||
|
||||
@staticmethod
|
||||
def sin(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return jnp.sin(z)
|
||||
|
||||
@staticmethod
|
||||
def gauss(z):
|
||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||
return jnp.exp(-z ** 2)
|
||||
|
||||
@staticmethod
|
||||
def relu(z):
|
||||
return jnp.maximum(z, 0)
|
||||
|
||||
@staticmethod
|
||||
def elu(z):
|
||||
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
||||
|
||||
@staticmethod
|
||||
def lelu(z):
|
||||
leaky = 0.005
|
||||
return jnp.where(z > 0, z, leaky * z)
|
||||
|
||||
@staticmethod
|
||||
def selu(z):
|
||||
lam = 1.0507009873554804934193349852946
|
||||
alpha = 1.6732632423543772848170429916717
|
||||
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
||||
|
||||
@staticmethod
|
||||
def softplus(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 0.2 * jnp.log(1 + jnp.exp(z))
|
||||
|
||||
@staticmethod
|
||||
def identity(z):
|
||||
return z
|
||||
@@ -58,7 +36,11 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def inv(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
z = jnp.where(
|
||||
z > 0,
|
||||
jnp.maximum(z, 1e-7),
|
||||
jnp.minimum(z, -1e-7)
|
||||
)
|
||||
return 1 / z
|
||||
|
||||
@staticmethod
|
||||
@@ -68,24 +50,27 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def exp(z):
|
||||
z = jnp.clip(z, -60, 60)
|
||||
z = jnp.clip(z, -10, 10)
|
||||
return jnp.exp(z)
|
||||
|
||||
@staticmethod
|
||||
def abs(z):
|
||||
return jnp.abs(z)
|
||||
|
||||
@staticmethod
|
||||
def hat(z):
|
||||
return jnp.maximum(0, 1 - jnp.abs(z))
|
||||
|
||||
@staticmethod
|
||||
def square(z):
|
||||
return z ** 2
|
||||
|
||||
@staticmethod
|
||||
def cube(z):
|
||||
return z ** 3
|
||||
ACT_ALL = (
|
||||
Act.sigmoid,
|
||||
Act.tanh,
|
||||
Act.sin,
|
||||
Act.relu,
|
||||
Act.lelu,
|
||||
Act.identity,
|
||||
Act.clamped,
|
||||
Act.inv,
|
||||
Act.log,
|
||||
Act.exp,
|
||||
Act.abs,
|
||||
)
|
||||
|
||||
|
||||
def act(idx, z, act_funcs):
|
||||
|
||||
@@ -51,6 +51,9 @@ class Agg:
|
||||
return mean_without_zeros
|
||||
|
||||
|
||||
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
|
||||
|
||||
|
||||
def agg(idx, z, agg_funcs):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
|
||||
Reference in New Issue
Block a user