gpu slice is very slow. fixed this problem
This commit is contained in:
@@ -89,32 +89,20 @@ class Pipeline:
|
|||||||
# prepare elitism mask and crossover pair
|
# prepare elitism mask and crossover pair
|
||||||
elitism_mask = np.full(self.pop_size, False)
|
elitism_mask = np.full(self.pop_size, False)
|
||||||
|
|
||||||
def aux3():
|
for i, pair in enumerate(crossover_pair):
|
||||||
nonlocal crossover_pair
|
if not isinstance(pair, tuple): # elitism
|
||||||
for i, pair in enumerate(crossover_pair):
|
elitism_mask[i] = True
|
||||||
if not isinstance(pair, tuple): # elitism
|
crossover_pair[i] = (pair, pair)
|
||||||
elitism_mask[i] = True
|
crossover_pair = np.array(crossover_pair)
|
||||||
crossover_pair[i] = (pair, pair)
|
|
||||||
crossover_pair = np.array(crossover_pair)
|
|
||||||
return elitism_mask
|
|
||||||
|
|
||||||
def aux4():
|
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
|
||||||
return crossover_rand_keys, mutate_rand_keys
|
|
||||||
|
|
||||||
elitism_mask = aux3()
|
# batch crossover
|
||||||
crossover_rand_keys, mutate_rand_keys = aux4()
|
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
||||||
|
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||||
def aux2():
|
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||||
# batch crossover
|
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||||
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
|
||||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
|
||||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
|
||||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
|
||||||
return wpn, wpc, lpn, lpc
|
|
||||||
|
|
||||||
wpn, wpc, lpn, lpc = aux2()
|
|
||||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||||
lpc) # new pop nodes, new pop connections
|
lpc) # new pop nodes, new pop connections
|
||||||
|
|
||||||
@@ -124,14 +112,10 @@ class Pipeline:
|
|||||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||||
|
|
||||||
# elitism don't mutate
|
# elitism don't mutate
|
||||||
def axu():
|
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
||||||
nonlocal npn, npc, m_npn, m_npc
|
|
||||||
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
|
||||||
|
|
||||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
||||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||||
|
|
||||||
axu()
|
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user