做了一些时间测试:

1. jit中的vmap一个函数不会触发重新编译
2. jax.lax.while_lop单独执行确实可以提前中断,但在vmap中的性能需要考虑。
This commit is contained in:
wls2002
2023-05-12 16:42:57 +08:00
parent 5e0130e109
commit 9b56f4ff73
2 changed files with 61 additions and 2 deletions

View File

@@ -0,0 +1,4 @@
from jax import jit
@jit
def jitable_speciate():
pass

View File

@@ -1,3 +1,58 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, vmap
from time_utils import using_cprofile
from time import time
print(np.random.permutation(10))
@jit
def fx(x, y):
return x + y
# @jit
def fy(z):
z1, z2 = z, z + 1
vmap_fx = vmap(fx)
return vmap_fx(z1, z2)
@jit
def test_while(num, init_val):
def cond_fun(carry):
i, cumsum = carry
return i < num
def body_fun(carry):
i, cumsum = carry
cumsum += i
return i + 1, cumsum
return jax.lax.while_loop(cond_fun, body_fun, (0, init_val))
@using_cprofile
def main():
z = jnp.zeros((100000, ))
num = 100
nums = jnp.arange(num) * 10
f = jit(vmap(test_while, in_axes=(0, None))).lower(nums, z).compile()
def test_time(*args):
return f(*args)
print(test_time(nums, z))
#
#
# for i in range(10):
# num = 10 ** i
# st = time()
# res = test_time(num, z)
# print(res)
# t = time() - st
# print(f"num: {num}, time: {t}")
if __name__ == '__main__':
main()