FAST!
This commit is contained in:
@@ -4,9 +4,10 @@ from jax import jit, vmap
|
||||
from time_utils import using_cprofile
|
||||
from time import time
|
||||
#
|
||||
import numpy as np
|
||||
@jit
|
||||
def fx(x, y):
|
||||
return x + y
|
||||
def fx(x):
|
||||
return jnp.arange(x, x + 10)
|
||||
#
|
||||
#
|
||||
# # @jit
|
||||
@@ -33,13 +34,15 @@ def fx(x, y):
|
||||
|
||||
# @using_cprofile
|
||||
def main():
|
||||
vmap_f = vmap(fx, in_axes=(None, 0))
|
||||
vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
|
||||
a = jnp.array([20,10,30])
|
||||
b = jnp.array([6, 5, 4])
|
||||
res = vmap_vmap_f(a, b)
|
||||
print(res)
|
||||
print(jnp.argmin(res, axis=1))
|
||||
print(fx(1))
|
||||
|
||||
# vmap_f = vmap(fx, in_axes=(None, 0))
|
||||
# vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
|
||||
# a = jnp.array([20,10,30])
|
||||
# b = jnp.array([6, 5, 4])
|
||||
# res = vmap_vmap_f(a, b)
|
||||
# print(res)
|
||||
# print(jnp.argmin(res, axis=1))
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user