finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -104,11 +104,23 @@ def cube_act(z):
return z ** 3
@jit
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
act_name2func = {
'sigmoid': sigmoid_act,
'tanh': tanh_act,
'sin': sin_act,
'gauss': gauss_act,
'relu': relu_act,
'elu': elu_act,
'lelu': lelu_act,
'selu': selu_act,
'softplus': softplus_act,
'identity': identity_act,
'clamped': clamped_act,
'inv': inv_act,
'log': log_act,
'exp': exp_act,
'abs': abs_act,
'hat': hat_act,
'square': square_act,
'cube': cube_act,
}