Files
tensorneat-mend/examples/jax_playground.py
wls2002 6006f92f3f finish jit-able speciate function
next time i'll create a new branch
2023-05-12 19:26:02 +08:00

48 lines
867 B
Python

import jax
import jax.numpy as jnp
from jax import jit, vmap
from time_utils import using_cprofile
from time import time
#
@jit
def fx(x, y):
    """Return ``x + y``, compiled with ``jax.jit``.

    Args:
        x: Left operand of the addition.
        y: Right operand of the addition.

    Returns:
        The result of ``x + y``.
    """
    return x + y
#
#
# # @jit
# def fy(z):
# z1, z2 = z, z + 1
# vmap_fx = vmap(fx)
# return vmap_fx(z1, z2)
#
# @jit
# def test_while(num, init_val):
# def cond_fun(carry):
# i, cumsum = carry
# return i < num
#
# def body_fun(carry):
# i, cumsum = carry
# cumsum += i
# return i + 1, cumsum
#
# return jax.lax.while_loop(cond_fun, body_fun, (0, init_val))
# @using_cprofile
def main():
    """Print the table of pairwise sums of two vectors and each row's argmin.

    ``fx`` is vectorised twice with ``vmap`` so that the result satisfies
    ``res[i, j] == a[i] + b[j]``. The table is printed first, followed by
    the column index of the smallest entry in every row.
    """
    # Inner map: keep the first argument fixed and sweep over the entries
    # of the second argument.
    vmap_f = vmap(fx, in_axes=(None, 0))
    # Outer map: sweep over the entries of the first argument while passing
    # the second argument through whole to the inner map.
    vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
    a = jnp.array([20, 10, 30])
    b = jnp.array([6, 5, 4])
    res = vmap_vmap_f(a, b)
    print(res)
    # For each a[i] (row), the index j into b that gives the smallest sum.
    print(jnp.argmin(res, axis=1))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()