initial commit
This commit is contained in:
37
examples/jax_playground.py
Normal file
37
examples/jax_playground.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import random
|
||||
from jax import vmap, jit
|
||||
|
||||
|
||||
def plus1(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
def minus1(x):
|
||||
return x - 1
|
||||
|
||||
|
||||
def func(rand_key, x):
|
||||
r = jax.random.uniform(rand_key, shape=())
|
||||
return jax.lax.cond(r > 0.5, plus1, minus1, x)
|
||||
|
||||
|
||||
def func2(rand_key):
|
||||
r = jax.random.uniform(rand_key, ())
|
||||
if r < 0.3:
|
||||
return 1
|
||||
elif r < 0.5:
|
||||
return 2
|
||||
else:
|
||||
return 3
|
||||
|
||||
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
print(func(key, 0))
|
||||
|
||||
batch_func = vmap(jit(func))
|
||||
keys = random.split(key, 100)
|
||||
print(batch_func(keys, jnp.zeros(100)))
|
||||
Reference in New Issue
Block a user