finish jit-able speciate function
next time i'll create a new branch
This commit is contained in:
@@ -76,6 +76,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
a = jnp.array([1, 2, 3, 4, 5])
|
||||
|
||||
Reference in New Issue
Block a user