From 9dfa904ce5f29bda5b73c0318f1e9f3b46c2e51d Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 10 May 2023 15:20:42 +0800 Subject: [PATCH] precompile jax.random.split --- algorithms/neat/function_factory.py | 7 +++++++ algorithms/neat/pipeline.py | 20 +++++++++----------- algorithms/neat/species.py | 2 +- utils/default_config.json | 4 ++-- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index 27cc42c..e914a69 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -3,6 +3,7 @@ Lowers, compiles, and creates functions used in the NEAT pipeline. """ from functools import partial +import jax.random import numpy as np from jax import jit, vmap @@ -113,6 +114,12 @@ class FunctionFactory: self.compile_topological_sort(n) self.compile_pop_batch_forward(n) n = int(self.expand_coe * n) + + # precompile other functions used in jax + key = jax.random.PRNGKey(0) + _ = jax.random.split(key, 2) + _ = jax.random.split(key, self.pop_size * 2) + print("end precompile") def create_mutate_with_args(self): diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index ae4d097..00e6dae 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -5,8 +5,9 @@ import jax import numpy as np from .species import SpeciesController -from .genome import expand, expand_single, create_forward_function +from .genome import expand, expand_single from .function_factory import FunctionFactory +from examples.time_utils import using_cprofile class Pipeline: @@ -15,12 +16,11 @@ class Pipeline: """ def __init__(self, config, seed=42): + self.function_factory = FunctionFactory(config) self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) self.config = config - self.function_factory = FunctionFactory(config) - self.N = config.basic.init_maximum_nodes self.expand_coe = config.basic.expands_coe self.pop_size = config.neat.population.pop_size @@ -45,11 +45,6 @@ class Pipeline: """ return self.function_factory.ask(self.pop_nodes, self.pop_connections) - # - # func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx, - # batch=batch) - # return func - def tell(self, fitnesses): self.generation += 1 @@ -80,6 +75,7 @@ class Pipeline: self.tell(fitnesses) print("Generation limit reached!") + # @using_cprofile def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None: """ create the next generation @@ -87,7 +83,7 @@ class Pipeline: """ assert self.pop_nodes.shape[0] == self.pop_size - k1, k2, self.randkey = jax.random.split(self.randkey, 3) + k, self.randkey = jax.random.split(self.randkey, 2) # crossover # prepare elitism mask and crossover pair @@ -99,7 +95,10 @@ class Pipeline: crossover_pair[i] = (pair, pair) crossover_pair = np.array(crossover_pair) - crossover_rand_keys = jax.random.split(k1, self.pop_size) + total_keys = jax.random.split(k, self.pop_size * 2) + crossover_rand_keys = total_keys[:self.pop_size, :] + mutate_rand_keys = total_keys[self.pop_size:, :] + # batch crossover wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections @@ -109,7 +108,6 @@ class Pipeline: lpc) # new pop nodes, new pop connections # mutate - mutate_rand_keys = jax.random.split(k2, self.pop_size) new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index 035ef58..9bdbc82 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -110,7 +110,7 @@ class SpeciesController: new_members[species_id].append(i) unspeciated[i] = False - # Second, slowly match the lonely population to new-created species. + # Second, slowly match the lonely population to new-created species.s # lonely genome is proved to be not compatible with any previous species, so they only need to be compared with # the new representatives. for i in range(pop_nodes.shape[0]): diff --git a/utils/default_config.json b/utils/default_config.json index b10133d..2f7fb8a 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -3,9 +3,9 @@ "num_inputs": 2, "num_outputs": 1, "problem_batch": 4, - "init_maximum_nodes": 20, + "init_maximum_nodes": 10, "expands_coe": 2, - "pre_compile_times": 0 + "pre_compile_times": 2 }, "neat": { "population": {