debug-branch
This commit is contained in:
@@ -134,5 +134,3 @@ def act(idx, z):
|
||||
# change idx from float to int
|
||||
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
|
||||
|
||||
vectorized_act = jax.vmap(act, in_axes=(0, 0))
|
||||
|
||||
Reference in New Issue
Block a user