precompile jax.random.split

This commit is contained in:
wls2002
2023-05-10 15:20:42 +08:00
parent 0fdc856f2d
commit 9dfa904ce5
4 changed files with 19 additions and 14 deletions

View File

@@ -3,6 +3,7 @@ Lowers, compiles, and creates functions used in the NEAT pipeline.
""" """
from functools import partial from functools import partial
import jax.random
import numpy as np import numpy as np
from jax import jit, vmap from jax import jit, vmap
@@ -113,6 +114,12 @@ class FunctionFactory:
self.compile_topological_sort(n) self.compile_topological_sort(n)
self.compile_pop_batch_forward(n) self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n) n = int(self.expand_coe * n)
# precompile other functions used in jax
key = jax.random.PRNGKey(0)
_ = jax.random.split(key, 2)
_ = jax.random.split(key, self.pop_size * 2)
print("end precompile") print("end precompile")
def create_mutate_with_args(self): def create_mutate_with_args(self):

View File

@@ -5,8 +5,9 @@ import jax
import numpy as np import numpy as np
from .species import SpeciesController from .species import SpeciesController
from .genome import expand, expand_single, create_forward_function from .genome import expand, expand_single
from .function_factory import FunctionFactory from .function_factory import FunctionFactory
from examples.time_utils import using_cprofile
class Pipeline: class Pipeline:
@@ -15,12 +16,11 @@ class Pipeline:
""" """
def __init__(self, config, seed=42): def __init__(self, config, seed=42):
self.function_factory = FunctionFactory(config)
self.randkey = jax.random.PRNGKey(seed) self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed) np.random.seed(seed)
self.config = config self.config = config
self.function_factory = FunctionFactory(config)
self.N = config.basic.init_maximum_nodes self.N = config.basic.init_maximum_nodes
self.expand_coe = config.basic.expands_coe self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size self.pop_size = config.neat.population.pop_size
@@ -45,11 +45,6 @@ class Pipeline:
""" """
return self.function_factory.ask(self.pop_nodes, self.pop_connections) return self.function_factory.ask(self.pop_nodes, self.pop_connections)
#
# func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
# batch=batch)
# return func
def tell(self, fitnesses): def tell(self, fitnesses):
self.generation += 1 self.generation += 1
@@ -80,6 +75,7 @@ class Pipeline:
self.tell(fitnesses) self.tell(fitnesses)
print("Generation limit reached!") print("Generation limit reached!")
# @using_cprofile
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None: def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
""" """
create the next generation create the next generation
@@ -87,7 +83,7 @@ class Pipeline:
""" """
assert self.pop_nodes.shape[0] == self.pop_size assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3) k, self.randkey = jax.random.split(self.randkey, 2)
# crossover # crossover
# prepare elitism mask and crossover pair # prepare elitism mask and crossover pair
@@ -99,7 +95,10 @@ class Pipeline:
crossover_pair[i] = (pair, pair) crossover_pair[i] = (pair, pair)
crossover_pair = np.array(crossover_pair) crossover_pair = np.array(crossover_pair)
crossover_rand_keys = jax.random.split(k1, self.pop_size) total_keys = jax.random.split(k, self.pop_size * 2)
crossover_rand_keys = total_keys[:self.pop_size, :]
mutate_rand_keys = total_keys[self.pop_size:, :]
# batch crossover # batch crossover
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
@@ -109,7 +108,6 @@ class Pipeline:
lpc) # new pop nodes, new pop connections lpc) # new pop nodes, new pop connections
# mutate # mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes

View File

@@ -110,7 +110,7 @@ class SpeciesController:
new_members[species_id].append(i) new_members[species_id].append(i)
unspeciated[i] = False unspeciated[i] = False
# Second, slowly match the lonely population to new-created species. # Second, slowly match the lonely population to new-created species.s
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with # lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
# the new representatives. # the new representatives.
for i in range(pop_nodes.shape[0]): for i in range(pop_nodes.shape[0]):

View File

@@ -3,9 +3,9 @@
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"problem_batch": 4, "problem_batch": 4,
"init_maximum_nodes": 20, "init_maximum_nodes": 10,
"expands_coe": 2, "expands_coe": 2,
"pre_compile_times": 0 "pre_compile_times": 2
}, },
"neat": { "neat": {
"population": { "population": {