precompile jax.random.split

This commit is contained in:
wls2002
2023-05-10 15:20:42 +08:00
parent 0fdc856f2d
commit 9dfa904ce5
4 changed files with 19 additions and 14 deletions

View File

@@ -5,8 +5,9 @@ import jax
import numpy as np
from .species import SpeciesController
from .genome import expand, expand_single, create_forward_function
from .genome import expand, expand_single
from .function_factory import FunctionFactory
from examples.time_utils import using_cprofile
class Pipeline:
@@ -15,12 +16,11 @@ class Pipeline:
"""
def __init__(self, config, seed=42):
self.function_factory = FunctionFactory(config)
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
self.config = config
self.function_factory = FunctionFactory(config)
self.N = config.basic.init_maximum_nodes
self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size
@@ -45,11 +45,6 @@ class Pipeline:
"""
return self.function_factory.ask(self.pop_nodes, self.pop_connections)
#
# func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
# batch=batch)
# return func
def tell(self, fitnesses):
self.generation += 1
@@ -80,6 +75,7 @@ class Pipeline:
self.tell(fitnesses)
print("Generation limit reached!")
# @using_cprofile
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
"""
create the next generation
@@ -87,7 +83,7 @@ class Pipeline:
"""
assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
k, self.randkey = jax.random.split(self.randkey, 2)
# crossover
# prepare elitism mask and crossover pair
@@ -99,7 +95,10 @@ class Pipeline:
crossover_pair[i] = (pair, pair)
crossover_pair = np.array(crossover_pair)
crossover_rand_keys = jax.random.split(k1, self.pop_size)
total_keys = jax.random.split(k, self.pop_size * 2)
crossover_rand_keys = total_keys[:self.pop_size, :]
mutate_rand_keys = total_keys[self.pop_size:, :]
# batch crossover
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
@@ -109,7 +108,6 @@ class Pipeline:
lpc) # new pop nodes, new pop connections
# mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes