precompile jax.random.split

This commit is contained in:
wls2002
2023-05-10 15:20:42 +08:00
parent 0fdc856f2d
commit 9dfa904ce5
4 changed files with 19 additions and 14 deletions

View File

@@ -3,6 +3,7 @@ Lowers, compiles, and creates functions used in the NEAT pipeline.
"""
from functools import partial
import jax.random
import numpy as np
from jax import jit, vmap
@@ -113,6 +114,12 @@ class FunctionFactory:
self.compile_topological_sort(n)
self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n)
# precompile other functions used in jax
key = jax.random.PRNGKey(0)
_ = jax.random.split(key, 2)
_ = jax.random.split(key, self.pop_size * 2)
print("end precompile")
def create_mutate_with_args(self):