debug-branch

This commit is contained in:
wls2002
2023-05-06 21:04:28 +08:00
parent 14fed83193
commit a85e6eba78
20 changed files with 1719 additions and 233 deletions

View File

@@ -117,10 +117,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
return fetch_first(cumsum >= target, default)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
@@ -129,6 +131,9 @@ if __name__ == '__main__':
print(fetch_last(a > 30))
rand_key = jax.random.PRNGKey(0)
for _ in range(100):
rand_key, _ = jax.random.split(rand_key)
print(fetch_random(rand_key, a > 0))
for t in [-1, 0, 1, 2, 3, 4, 5]:
for _ in range(10):
rand_key, _ = jax.random.split(rand_key)
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
print(t, fetch_random(rand_key, a > t))