This commit is contained in:
wls2002
2023-05-13 20:58:03 +08:00
parent 90a9cc322d
commit 72c9d4167a
10 changed files with 372 additions and 529 deletions

View File

@@ -4,9 +4,10 @@ from jax import jit, vmap
from time_utils import using_cprofile
from time import time
#
import numpy as np
@jit
def fx(x, y):
return x + y
def fx(x):
return jnp.arange(x, x + 10)
#
#
# # @jit
@@ -33,13 +34,15 @@ def fx(x, y):
# @using_cprofile
def main():
vmap_f = vmap(fx, in_axes=(None, 0))
vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
a = jnp.array([20,10,30])
b = jnp.array([6, 5, 4])
res = vmap_vmap_f(a, b)
print(res)
print(jnp.argmin(res, axis=1))
print(fx(1))
# vmap_f = vmap(fx, in_axes=(None, 0))
# vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
# a = jnp.array([20,10,30])
# b = jnp.array([6, 5, 4])
# res = vmap_vmap_f(a, b)
# print(res)
# print(jnp.argmin(res, axis=1))