remove create_func....
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Activation:
|
||||
|
||||
name2func = {}
|
||||
|
||||
@staticmethod
|
||||
@@ -89,23 +89,11 @@ class Activation:
|
||||
return z ** 3
|
||||
|
||||
|
||||
Activation.name2func = {
|
||||
'sigmoid': Activation.sigmoid_act,
|
||||
'tanh': Activation.tanh_act,
|
||||
'sin': Activation.sin_act,
|
||||
'gauss': Activation.gauss_act,
|
||||
'relu': Activation.relu_act,
|
||||
'elu': Activation.elu_act,
|
||||
'lelu': Activation.lelu_act,
|
||||
'selu': Activation.selu_act,
|
||||
'softplus': Activation.softplus_act,
|
||||
'identity': Activation.identity_act,
|
||||
'clamped': Activation.clamped_act,
|
||||
'inv': Activation.inv_act,
|
||||
'log': Activation.log_act,
|
||||
'exp': Activation.exp_act,
|
||||
'abs': Activation.abs_act,
|
||||
'hat': Activation.hat_act,
|
||||
'square': Activation.square_act,
|
||||
'cube': Activation.cube_act,
|
||||
}
|
||||
def act(idx, z, act_funcs):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, act_funcs, z)
|
||||
return res
|
||||
|
||||
Reference in New Issue
Block a user