prepare for experiment

This commit is contained in:
wls2002
2023-05-14 15:27:17 +08:00
parent 72c9d4167a
commit 2b79f2c903
11 changed files with 252 additions and 62 deletions

View File

@@ -133,5 +133,8 @@ act_name2key = {
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)