Files
tensorneat-mend/examples/jax_playground.py
2023-05-08 15:58:09 +08:00

21 lines
338 B
Python

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax import vmap, jit
from functools import partial
from examples.time_utils import using_cprofile
@jit
def func(x, y):
return x + y
a, b, c = jnp.array([1]), jnp.array([2]), jnp.array([3])
li = [a, b, c]
cpu_li = jax.device_get(li)
print(cpu_li)