37 lines
598 B
Python
37 lines
598 B
Python
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from jax import random
|
|
from jax import vmap, jit
|
|
|
|
|
|
def plus1(x):
    """Return ``x`` incremented by one."""
    incremented = x + 1
    return incremented
|
|
|
|
|
|
def minus1(x):
    """Return ``x`` decremented by one."""
    decremented = x - 1
    return decremented
|
|
|
|
|
|
def func(rand_key, x):
    """Apply ``plus1`` or ``minus1`` to ``x``, selected by one uniform draw.

    A scalar sample is drawn with ``jax.random.uniform`` from ``rand_key``.
    ``jax.lax.cond`` then dispatches to ``plus1`` when the sample is greater
    than 0.5 and to ``minus1`` otherwise, passing ``x`` as the operand.
    """
    sample = jax.random.uniform(rand_key, shape=())
    take_plus_branch = sample > 0.5
    return jax.lax.cond(take_plus_branch, plus1, minus1, x)
|
|
|
|
|
|
def func2(rand_key):
    """Return 1, 2 or 3 depending on where one uniform draw falls.

    A scalar sample is drawn with ``jax.random.uniform`` from ``rand_key``:

    * sample < 0.3        -> 1
    * 0.3 <= sample < 0.5 -> 2
    * otherwise           -> 3

    NOTE: the thresholds are tested with ordinary Python ``if`` statements,
    which need a concrete boolean. JAX documents that converting a traced
    value to ``bool`` raises an error, so this function is only usable when
    called eagerly, not when traced (e.g. under ``jit``).
    """
    sample = jax.random.uniform(rand_key, ())
    # Guard clauses, checked from the lowest threshold upwards.
    if sample < 0.3:
        return 1
    if sample < 0.5:
        return 2
    return 3
|
|
|
|
|
|
|
|
# Demo driver: evaluate `func` once with a single key and a scalar operand.
key = random.PRNGKey(0)
print(func(key, 0))

# Batched variant: `func` is wrapped in `jit` first and the result is then
# wrapped in `vmap`, which by default maps over axis 0 of every argument.
batch_func = vmap(jit(func))

# 100 keys derived from `key`, paired with a length-100 vector of zeros.
keys = random.split(key, num=100)
print(batch_func(keys, jnp.zeros((100,))))