diff --git a/algorithms/neat/jitable_speciate.py b/algorithms/neat/jitable_speciate.py new file mode 100644 index 0000000..c9b16c4 --- /dev/null +++ b/algorithms/neat/jitable_speciate.py @@ -0,0 +1,4 @@ +from jax import jit +@jit +def jitable_speciate(): + pass \ No newline at end of file diff --git a/examples/jax_playground.py b/examples/jax_playground.py index a3bbcbc..226532e 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,3 +1,58 @@ -import numpy as np +import jax +import jax.numpy as jnp +from jax import jit, vmap +from time_utils import using_cprofile +from time import time -print(np.random.permutation(10)) \ No newline at end of file +@jit +def fx(x, y): + return x + y + + +# @jit +def fy(z): + z1, z2 = z, z + 1 + vmap_fx = vmap(fx) + return vmap_fx(z1, z2) + +@jit +def test_while(num, init_val): + def cond_fun(carry): + i, cumsum = carry + return i < num + + def body_fun(carry): + i, cumsum = carry + cumsum += i + return i + 1, cumsum + + return jax.lax.while_loop(cond_fun, body_fun, (0, init_val)) + + + +@using_cprofile +def main(): + z = jnp.zeros((100000, )) + + num = 100 + + nums = jnp.arange(num) * 10 + + f = jit(vmap(test_while, in_axes=(0, None))).lower(nums, z).compile() + def test_time(*args): + return f(*args) + + print(test_time(nums, z)) + + # + # + # for i in range(10): + # num = 10 ** i + # st = time() + # res = test_time(num, z) + # print(res) + # t = time() - st + # print(f"num: {num}, time: {t}") + +if __name__ == '__main__': + main()