finish jit-able speciate function

next time i'll create a new branch
This commit is contained in:
wls2002
2023-05-12 19:26:02 +08:00
parent 9b56f4ff73
commit 6006f92f3f
6 changed files with 212 additions and 54 deletions

View File

@@ -76,6 +76,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
return fetch_first(mask, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])