Current Progress: After final design presentation

This commit is contained in:
wls2002
2023-06-19 15:17:56 +08:00
parent acedd67617
commit 5cbe3c14bb
34 changed files with 533 additions and 558 deletions

View File

@@ -104,31 +104,6 @@ def cube_act(z):
return z ** 3
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
act_name2key = {
'sigmoid': 0,
'tanh': 1,
'sin': 2,
'gauss': 3,
'relu': 4,
'elu': 5,
'lelu': 6,
'selu': 7,
'softplus': 8,
'identity': 9,
'clamped': 10,
'inv': 11,
'log': 12,
'exp': 13,
'abs': 14,
'hat': 15,
'square': 16,
'cube': 17,
}
@jit
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
@@ -137,4 +112,3 @@ def act(idx, z):
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)