做了一些时间测试:
1. jit中的vmap一个函数不会触发重新编译 2. jax.lax.while_lop单独执行确实可以提前中断,但在vmap中的性能需要考虑。
This commit is contained in:
4
algorithms/neat/jitable_speciate.py
Normal file
4
algorithms/neat/jitable_speciate.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from jax import jit
|
||||||
|
@jit
|
||||||
|
def jitable_speciate():
|
||||||
|
pass
|
||||||
@@ -1,3 +1,58 @@
|
|||||||
import numpy as np
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import jit, vmap
|
||||||
|
from time_utils import using_cprofile
|
||||||
|
from time import time
|
||||||
|
|
||||||
print(np.random.permutation(10))
|
@jit
|
||||||
|
def fx(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
|
# @jit
|
||||||
|
def fy(z):
|
||||||
|
z1, z2 = z, z + 1
|
||||||
|
vmap_fx = vmap(fx)
|
||||||
|
return vmap_fx(z1, z2)
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def test_while(num, init_val):
|
||||||
|
def cond_fun(carry):
|
||||||
|
i, cumsum = carry
|
||||||
|
return i < num
|
||||||
|
|
||||||
|
def body_fun(carry):
|
||||||
|
i, cumsum = carry
|
||||||
|
cumsum += i
|
||||||
|
return i + 1, cumsum
|
||||||
|
|
||||||
|
return jax.lax.while_loop(cond_fun, body_fun, (0, init_val))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@using_cprofile
|
||||||
|
def main():
|
||||||
|
z = jnp.zeros((100000, ))
|
||||||
|
|
||||||
|
num = 100
|
||||||
|
|
||||||
|
nums = jnp.arange(num) * 10
|
||||||
|
|
||||||
|
f = jit(vmap(test_while, in_axes=(0, None))).lower(nums, z).compile()
|
||||||
|
def test_time(*args):
|
||||||
|
return f(*args)
|
||||||
|
|
||||||
|
print(test_time(nums, z))
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# for i in range(10):
|
||||||
|
# num = 10 ** i
|
||||||
|
# st = time()
|
||||||
|
# res = test_time(num, z)
|
||||||
|
# print(res)
|
||||||
|
# t = time() - st
|
||||||
|
# print(f"num: {num}, time: {t}")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user