finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
This commit is contained in:
@@ -104,11 +104,23 @@ def cube_act(z):
|
||||
return z ** 3
|
||||
|
||||
|
||||
@jit
|
||||
def act(idx, z):
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
return jnp.where(jnp.isnan(res), jnp.nan, res)
|
||||
|
||||
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
act_name2func = {
|
||||
'sigmoid': sigmoid_act,
|
||||
'tanh': tanh_act,
|
||||
'sin': sin_act,
|
||||
'gauss': gauss_act,
|
||||
'relu': relu_act,
|
||||
'elu': elu_act,
|
||||
'lelu': lelu_act,
|
||||
'selu': selu_act,
|
||||
'softplus': softplus_act,
|
||||
'identity': identity_act,
|
||||
'clamped': clamped_act,
|
||||
'inv': inv_act,
|
||||
'log': log_act,
|
||||
'exp': exp_act,
|
||||
'abs': abs_act,
|
||||
'hat': hat_act,
|
||||
'square': square_act,
|
||||
'cube': cube_act,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user