From 9b56f4ff73a2972ece5b11eaf13c1330e72978b1 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 12 May 2023 16:42:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=81=9A=E4=BA=86=E4=B8=80=E4=BA=9B=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E6=B5=8B=E8=AF=95=EF=BC=9A=201.=20jit=E4=B8=AD?= =?UTF-8?q?=E7=9A=84vmap=E4=B8=80=E4=B8=AA=E5=87=BD=E6=95=B0=E4=B8=8D?= =?UTF-8?q?=E4=BC=9A=E8=A7=A6=E5=8F=91=E9=87=8D=E6=96=B0=E7=BC=96=E8=AF=91?= =?UTF-8?q?=202.=20jax.lax.while=5Flop=E5=8D=95=E7=8B=AC=E6=89=A7=E8=A1=8C?= =?UTF-8?q?=E7=A1=AE=E5=AE=9E=E5=8F=AF=E4=BB=A5=E6=8F=90=E5=89=8D=E4=B8=AD?= =?UTF-8?q?=E6=96=AD=EF=BC=8C=E4=BD=86=E5=9C=A8vmap=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E6=80=A7=E8=83=BD=E9=9C=80=E8=A6=81=E8=80=83=E8=99=91=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- algorithms/neat/jitable_speciate.py | 4 ++ examples/jax_playground.py | 59 ++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 algorithms/neat/jitable_speciate.py diff --git a/algorithms/neat/jitable_speciate.py b/algorithms/neat/jitable_speciate.py new file mode 100644 index 0000000..c9b16c4 --- /dev/null +++ b/algorithms/neat/jitable_speciate.py @@ -0,0 +1,4 @@ +from jax import jit +@jit +def jitable_speciate(): + pass \ No newline at end of file diff --git a/examples/jax_playground.py b/examples/jax_playground.py index a3bbcbc..226532e 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,3 +1,58 @@ -import numpy as np +import jax +import jax.numpy as jnp +from jax import jit, vmap +from time_utils import using_cprofile +from time import time -print(np.random.permutation(10)) \ No newline at end of file +@jit +def fx(x, y): + return x + y + + +# @jit +def fy(z): + z1, z2 = z, z + 1 + vmap_fx = vmap(fx) + return vmap_fx(z1, z2) + +@jit +def test_while(num, init_val): + def cond_fun(carry): + i, cumsum = carry + return i < num + + def body_fun(carry): + i, cumsum = carry + cumsum += i + return i + 1, cumsum + + return jax.lax.while_loop(cond_fun, body_fun, (0, init_val)) + + + +@using_cprofile +def main(): + z = jnp.zeros((100000, )) + + num = 100 + + nums = jnp.arange(num) * 10 + + f = jit(vmap(test_while, in_axes=(0, None))).lower(nums, z).compile() + def test_time(*args): + return f(*args) + + print(test_time(nums, z)) + + # + # + # for i in range(10): + # num = 10 ** i + # st = time() + # res = test_time(num, z) + # print(res) + # t = time() - st + # print(f"num: {num}, time: {t}") + +if __name__ == '__main__': + main()