debug-branch

This commit is contained in:
wls2002
2023-05-06 21:04:28 +08:00
parent 14fed83193
commit a85e6eba78
20 changed files with 1719 additions and 233 deletions

View File

@@ -134,5 +134,3 @@ def act(idx, z):
# change idx from float to int
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
vectorized_act = jax.vmap(act, in_axes=(0, 0))