use jit().lower.compile in create functions

This commit is contained in:
wls2002
2023-05-08 02:35:04 +08:00
parent 497d89fc69
commit d4a75b9394
9 changed files with 120 additions and 77 deletions

View File

@@ -3,6 +3,7 @@ import jax.numpy as jnp
import numpy as np
from jax import random
from jax import vmap, jit
from functools import partial
from examples.time_utils import using_cprofile
@@ -16,28 +17,43 @@ def func(x, y):
return x * y
def func2(x, y, s):
"""
:param x: (100, )
:param y: (100,
:return:
"""
if s == '123':
return 0
else:
return x * y
@jit
def func3(x, y):
return func2(x, y, '123')
# @using_cprofile
def main():
key = jax.random.PRNGKey(42)
x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,))
x1, y1 = jax.random.normal(key, shape=(1000,)), jax.random.normal(key, shape=(1000,))
jit_func = jit(func)
z = jit_func(x1, y1)
print(z)
jit_lower_func = jit(func).lower(x1, y1).compile()
jit_lower_func = jit(func).lower(1, 2).compile()
print(type(jit_lower_func))
import pickle
print(jit_lower_func.memory_analysis())
with open('jit_function.pkl', 'wb') as f:
pickle.dump(jit_lower_func, f)
jit_compiled_func2 = jit(func2, static_argnames=['s']).lower(x1, y1, '123').compile()
print(jit_compiled_func2(x1, y1))
new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb'))
# print(jit_compiled_func2(x1, y1))
print(jit_lower_func(x1, y1))
print(new_jit_lower_func(x1, y1))
f = func3.lower(x1, y1).compile()
print(f(x1, y1))
# print(jit_lower_func(x1, y1))
# x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,))
# print(jit_lower_func(x2, y2))

View File

@@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]:
return fitnesses.tolist() # returns a list
@using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
# @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main():
config = Configer.load_config()
pipeline = Pipeline(config, seed=11323)