generally complete, but not work well. Debug
This commit is contained in:
@@ -5,33 +5,10 @@ from jax import random
|
||||
from jax import vmap, jit
|
||||
|
||||
|
||||
def plus1(x):
|
||||
return x + 1
|
||||
seed = jax.random.PRNGKey(42)
|
||||
seed, *subkeys = random.split(seed, 3)
|
||||
|
||||
|
||||
def minus1(x):
|
||||
return x - 1
|
||||
|
||||
|
||||
def func(rand_key, x):
|
||||
r = jax.random.uniform(rand_key, shape=())
|
||||
return jax.lax.cond(r > 0.5, plus1, minus1, x)
|
||||
|
||||
|
||||
def func2(rand_key):
|
||||
r = jax.random.uniform(rand_key, ())
|
||||
if r < 0.3:
|
||||
return 1
|
||||
elif r < 0.5:
|
||||
return 2
|
||||
else:
|
||||
return 3
|
||||
|
||||
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
print(func(key, 0))
|
||||
|
||||
batch_func = vmap(jit(func))
|
||||
keys = random.split(key, 100)
|
||||
print(batch_func(keys, jnp.zeros(100)))
|
||||
c = random.split(seed, 1)
|
||||
print(seed, subkeys)
|
||||
print(c)
|
||||
Reference in New Issue
Block a user