debug-branch
This commit is contained in:
@@ -117,10 +117,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
true_cnt = jnp.sum(mask)
|
||||
cumsum = jnp.cumsum(mask)
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
return fetch_first(cumsum >= target, default)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
a = jnp.array([1, 2, 3, 4, 5])
|
||||
print(fetch_first(a > 3))
|
||||
print(fetch_first(a > 30))
|
||||
@@ -129,6 +131,9 @@ if __name__ == '__main__':
|
||||
print(fetch_last(a > 30))
|
||||
|
||||
rand_key = jax.random.PRNGKey(0)
|
||||
for _ in range(100):
|
||||
rand_key, _ = jax.random.split(rand_key)
|
||||
print(fetch_random(rand_key, a > 0))
|
||||
|
||||
for t in [-1, 0, 1, 2, 3, 4, 5]:
|
||||
for _ in range(10):
|
||||
rand_key, _ = jax.random.split(rand_key)
|
||||
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
|
||||
print(t, fetch_random(rand_key, a > t))
|
||||
|
||||
Reference in New Issue
Block a user