From 097bbf6631de3f0e3f3ab8235f058544f5bb7e4a Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 10 May 2023 18:36:22 +0800 Subject: [PATCH] gpu slice is very slow. fixed this problem --- algorithms/neat/pipeline.py | 46 ++++++++++++------------------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 5caf20c..6dcbaab 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -89,32 +89,20 @@ class Pipeline: # prepare elitism mask and crossover pair elitism_mask = np.full(self.pop_size, False) - def aux3(): - nonlocal crossover_pair - for i, pair in enumerate(crossover_pair): - if not isinstance(pair, tuple): # elitism - elitism_mask[i] = True - crossover_pair[i] = (pair, pair) - crossover_pair = np.array(crossover_pair) - return elitism_mask + for i, pair in enumerate(crossover_pair): + if not isinstance(pair, tuple): # elitism + elitism_mask[i] = True + crossover_pair[i] = (pair, pair) + crossover_pair = np.array(crossover_pair) - def aux4(): - crossover_rand_keys = jax.random.split(k1, self.pop_size) - mutate_rand_keys = jax.random.split(k2, self.pop_size) - return crossover_rand_keys, mutate_rand_keys + crossover_rand_keys = jax.random.split(k1, self.pop_size) + mutate_rand_keys = jax.random.split(k2, self.pop_size) - elitism_mask = aux3() - crossover_rand_keys, mutate_rand_keys = aux4() - - def aux2(): - # batch crossover - wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes - wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections - lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes - lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections - return wpn, wpc, lpn, lpc - - wpn, wpc, lpn, lpc = aux2() + # batch crossover + wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes + wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections + lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes + lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections @@ -124,14 +112,10 @@ class Pipeline: m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes # elitism don't mutate - def axu(): - nonlocal npn, npc, m_npn, m_npc - npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) + npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) - self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) - self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) - - axu() + self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) + self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) def expand(self): """