gpu slice is very slow. fixed this problem

This commit is contained in:
wls2002
2023-05-10 18:36:22 +08:00
parent 3f37d79d06
commit 097bbf6631

View File

@@ -89,32 +89,20 @@ class Pipeline:
# prepare elitism mask and crossover pair # prepare elitism mask and crossover pair
elitism_mask = np.full(self.pop_size, False) elitism_mask = np.full(self.pop_size, False)
def aux3(): for i, pair in enumerate(crossover_pair):
nonlocal crossover_pair if not isinstance(pair, tuple): # elitism
for i, pair in enumerate(crossover_pair): elitism_mask[i] = True
if not isinstance(pair, tuple): # elitism crossover_pair[i] = (pair, pair)
elitism_mask[i] = True crossover_pair = np.array(crossover_pair)
crossover_pair[i] = (pair, pair)
crossover_pair = np.array(crossover_pair)
return elitism_mask
def aux4(): crossover_rand_keys = jax.random.split(k1, self.pop_size)
crossover_rand_keys = jax.random.split(k1, self.pop_size) mutate_rand_keys = jax.random.split(k2, self.pop_size)
mutate_rand_keys = jax.random.split(k2, self.pop_size)
return crossover_rand_keys, mutate_rand_keys
elitism_mask = aux3() # batch crossover
crossover_rand_keys, mutate_rand_keys = aux4() wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
def aux2(): lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
# batch crossover lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
return wpn, wpc, lpn, lpc
wpn, wpc, lpn, lpc = aux2()
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections lpc) # new pop nodes, new pop connections
@@ -124,14 +112,10 @@ class Pipeline:
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate # elitism don't mutate
def axu(): npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
nonlocal npn, npc, m_npn, m_npc
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
axu()
def expand(self): def expand(self):
""" """