From 3f37d79d06dd47cfb75227437736716629a9aff7 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 10 May 2023 17:48:07 +0800 Subject: [PATCH] gpu slice is very slow. fix this problem --- algorithms/neat/function_factory.py | 52 +++++++++++++++++++++++++---- algorithms/neat/pipeline.py | 52 +++++++++++++++++++---------- examples/xor.py | 6 ++-- utils/default_config.json | 4 +-- 4 files changed, 84 insertions(+), 30 deletions(-) diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index e914a69..d024d9c 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -94,7 +94,10 @@ class FunctionFactory: default_weight=self.weight_mean ) if self.debug: - return lambda *args: func(*args) + def debug_initialize(*args): + return func(*args) + + return debug_initialize else: return func @@ -117,8 +120,9 @@ class FunctionFactory: # precompile other functions used in jax key = jax.random.PRNGKey(0) - _ = jax.random.split(key, 2) + _ = jax.random.split(key, 3) _ = jax.random.split(key, self.pop_size * 2) + _ = jax.random.split(key, self.pop_size) print("end precompile") @@ -171,7 +175,14 @@ class FunctionFactory: key = ('mutate', n) if key not in self.compiled_function: self.compile_mutate(n) - return self.compiled_function[key] + if self.debug: + def debug_mutate(*args): + res_nodes, res_connections = self.compiled_function[key](*args) + return res_nodes.block_until_ready(), res_connections.block_until_ready() + + return debug_mutate + else: + return self.compiled_function[key] def create_distance_with_args(self): func = partial( @@ -203,7 +214,17 @@ class FunctionFactory: key1, key2 = ('o2o_distance', n), ('o2m_distance', n) if key1 not in self.compiled_function: self.compile_distance(n) - return self.compiled_function[key1], self.compiled_function[key2] + if self.debug: + + def debug_o2o_distance(*args): + return self.compiled_function[key1](*args).block_until_ready() + + def debug_o2m_distance(*args): + return self.compiled_function[key2](*args).block_until_ready() + + return debug_o2o_distance, debug_o2m_distance + else: + return self.compiled_function[key1], self.compiled_function[key2] def create_crossover_with_args(self): self.crossover_with_args = crossover @@ -223,7 +244,15 @@ class FunctionFactory: key = ('crossover', n) if key not in self.compiled_function: self.compile_crossover(n) - return self.compiled_function[key] + if self.debug: + + def debug_crossover(*args): + res_nodes, res_connections = self.compiled_function[key](*args) + return res_nodes.block_until_ready(), res_connections.block_until_ready() + + return debug_crossover + else: + return self.compiled_function[key] def create_topological_sort_with_args(self): self.topological_sort_with_args = topological_sort @@ -303,7 +332,13 @@ class FunctionFactory: key = ('pop_batch_forward', n) if key not in self.compiled_function: self.compile_pop_batch_forward(n) - return self.compiled_function[key] + if self.debug: + def debug_pop_batch_forward(*args): + return self.compiled_function[key](*args).block_until_ready() + + return debug_pop_batch_forward + else: + return self.compiled_function[key] def ask(self, pop_nodes, pop_connections): n = pop_nodes.shape[1] @@ -312,7 +347,10 @@ class FunctionFactory: forward_func = self.create_pop_batch_forward(n) - return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections) + def debug_forward(inputs): + return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections) + + return debug_forward # return partial( # forward_func, diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 00e6dae..5caf20c 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -16,7 +16,7 @@ class Pipeline: """ def __init__(self, config, seed=42): - self.function_factory = FunctionFactory(config) + self.function_factory = FunctionFactory(config, debug=True) self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) @@ -83,27 +83,38 @@ class Pipeline: """ assert self.pop_nodes.shape[0] == self.pop_size - k, self.randkey = jax.random.split(self.randkey, 2) + k1, k2, self.randkey = jax.random.split(self.randkey, 3) # crossover # prepare elitism mask and crossover pair elitism_mask = np.full(self.pop_size, False) - for i, pair in enumerate(crossover_pair): - if not isinstance(pair, tuple): # elitism - elitism_mask[i] = True - crossover_pair[i] = (pair, pair) - crossover_pair = np.array(crossover_pair) + def aux3(): + nonlocal crossover_pair + for i, pair in enumerate(crossover_pair): + if not isinstance(pair, tuple): # elitism + elitism_mask[i] = True + crossover_pair[i] = (pair, pair) + crossover_pair = np.array(crossover_pair) + return elitism_mask - total_keys = jax.random.split(k, self.pop_size * 2) - crossover_rand_keys = total_keys[:self.pop_size, :] - mutate_rand_keys = total_keys[self.pop_size:, :] + def aux4(): + crossover_rand_keys = jax.random.split(k1, self.pop_size) + mutate_rand_keys = jax.random.split(k2, self.pop_size) + return crossover_rand_keys, mutate_rand_keys - # batch crossover - wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes - wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections - lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes - lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections + elitism_mask = aux3() + crossover_rand_keys, mutate_rand_keys = aux4() + + def aux2(): + # batch crossover + wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes + wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections + lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes + lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections + return wpn, wpc, lpn, lpc + + wpn, wpc, lpn, lpc = aux2() npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections @@ -113,9 +124,14 @@ class Pipeline: m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes # elitism don't mutate - npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) - self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) - self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) + def axu(): + nonlocal npn, npc, m_npn, m_npc + npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) + + self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) + self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) + + axu() def expand(self): """ diff --git a/examples/xor.py b/examples/xor.py index 13c40cc..bfec783 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -23,11 +23,11 @@ def evaluate(forward_func: Callable) -> List[float]: return fitnesses.tolist() # returns a list -@using_cprofile -# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") +# @using_cprofile +@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() - pipeline = Pipeline(config, seed=114514) + pipeline = Pipeline(config, seed=11454) pipeline.auto_run(evaluate) diff --git a/utils/default_config.json b/utils/default_config.json index 2f7fb8a..1b2be0e 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -5,14 +5,14 @@ "problem_batch": 4, "init_maximum_nodes": 10, "expands_coe": 2, - "pre_compile_times": 2 + "pre_compile_times": 3 }, "neat": { "population": { "fitness_criterion": "max", "fitness_threshold": 76, "generation_limit": 100, - "pop_size": 1000, + "pop_size": 2000, "reset_on_extinction": "False" }, "gene": {