precompile jax.random.split
This commit is contained in:
@@ -3,6 +3,7 @@ Lowers, compiles, and creates functions used in the NEAT pipeline.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import jax.random
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
@@ -113,6 +114,12 @@ class FunctionFactory:
|
||||
self.compile_topological_sort(n)
|
||||
self.compile_pop_batch_forward(n)
|
||||
n = int(self.expand_coe * n)
|
||||
|
||||
# precompile other functions used in jax
|
||||
key = jax.random.PRNGKey(0)
|
||||
_ = jax.random.split(key, 2)
|
||||
_ = jax.random.split(key, self.pop_size * 2)
|
||||
|
||||
print("end precompile")
|
||||
|
||||
def create_mutate_with_args(self):
|
||||
|
||||
Reference in New Issue
Block a user