gpu slice is very slow. fix this problem
This commit is contained in:
@@ -94,7 +94,10 @@ class FunctionFactory:
|
|||||||
default_weight=self.weight_mean
|
default_weight=self.weight_mean
|
||||||
)
|
)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
return lambda *args: func(*args)
|
def debug_initialize(*args):
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
|
return debug_initialize
|
||||||
else:
|
else:
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@@ -117,8 +120,9 @@ class FunctionFactory:
|
|||||||
|
|
||||||
# precompile other functions used in jax
|
# precompile other functions used in jax
|
||||||
key = jax.random.PRNGKey(0)
|
key = jax.random.PRNGKey(0)
|
||||||
_ = jax.random.split(key, 2)
|
_ = jax.random.split(key, 3)
|
||||||
_ = jax.random.split(key, self.pop_size * 2)
|
_ = jax.random.split(key, self.pop_size * 2)
|
||||||
|
_ = jax.random.split(key, self.pop_size)
|
||||||
|
|
||||||
print("end precompile")
|
print("end precompile")
|
||||||
|
|
||||||
@@ -171,7 +175,14 @@ class FunctionFactory:
|
|||||||
key = ('mutate', n)
|
key = ('mutate', n)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_mutate(n)
|
self.compile_mutate(n)
|
||||||
return self.compiled_function[key]
|
if self.debug:
|
||||||
|
def debug_mutate(*args):
|
||||||
|
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||||
|
return res_nodes.block_until_ready(), res_connections.block_until_ready()
|
||||||
|
|
||||||
|
return debug_mutate
|
||||||
|
else:
|
||||||
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def create_distance_with_args(self):
|
def create_distance_with_args(self):
|
||||||
func = partial(
|
func = partial(
|
||||||
@@ -203,7 +214,17 @@ class FunctionFactory:
|
|||||||
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
|
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
|
||||||
if key1 not in self.compiled_function:
|
if key1 not in self.compiled_function:
|
||||||
self.compile_distance(n)
|
self.compile_distance(n)
|
||||||
return self.compiled_function[key1], self.compiled_function[key2]
|
if self.debug:
|
||||||
|
|
||||||
|
def debug_o2o_distance(*args):
|
||||||
|
return self.compiled_function[key1](*args).block_until_ready()
|
||||||
|
|
||||||
|
def debug_o2m_distance(*args):
|
||||||
|
return self.compiled_function[key2](*args).block_until_ready()
|
||||||
|
|
||||||
|
return debug_o2o_distance, debug_o2m_distance
|
||||||
|
else:
|
||||||
|
return self.compiled_function[key1], self.compiled_function[key2]
|
||||||
|
|
||||||
def create_crossover_with_args(self):
|
def create_crossover_with_args(self):
|
||||||
self.crossover_with_args = crossover
|
self.crossover_with_args = crossover
|
||||||
@@ -223,7 +244,15 @@ class FunctionFactory:
|
|||||||
key = ('crossover', n)
|
key = ('crossover', n)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_crossover(n)
|
self.compile_crossover(n)
|
||||||
return self.compiled_function[key]
|
if self.debug:
|
||||||
|
|
||||||
|
def debug_crossover(*args):
|
||||||
|
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||||
|
return res_nodes.block_until_ready(), res_connections.block_until_ready()
|
||||||
|
|
||||||
|
return debug_crossover
|
||||||
|
else:
|
||||||
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def create_topological_sort_with_args(self):
|
def create_topological_sort_with_args(self):
|
||||||
self.topological_sort_with_args = topological_sort
|
self.topological_sort_with_args = topological_sort
|
||||||
@@ -303,7 +332,13 @@ class FunctionFactory:
|
|||||||
key = ('pop_batch_forward', n)
|
key = ('pop_batch_forward', n)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_pop_batch_forward(n)
|
self.compile_pop_batch_forward(n)
|
||||||
return self.compiled_function[key]
|
if self.debug:
|
||||||
|
def debug_pop_batch_forward(*args):
|
||||||
|
return self.compiled_function[key](*args).block_until_ready()
|
||||||
|
|
||||||
|
return debug_pop_batch_forward
|
||||||
|
else:
|
||||||
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def ask(self, pop_nodes, pop_connections):
|
def ask(self, pop_nodes, pop_connections):
|
||||||
n = pop_nodes.shape[1]
|
n = pop_nodes.shape[1]
|
||||||
@@ -312,7 +347,10 @@ class FunctionFactory:
|
|||||||
|
|
||||||
forward_func = self.create_pop_batch_forward(n)
|
forward_func = self.create_pop_batch_forward(n)
|
||||||
|
|
||||||
return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
|
def debug_forward(inputs):
|
||||||
|
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
|
||||||
|
|
||||||
|
return debug_forward
|
||||||
|
|
||||||
# return partial(
|
# return partial(
|
||||||
# forward_func,
|
# forward_func,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class Pipeline:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, seed=42):
|
def __init__(self, config, seed=42):
|
||||||
self.function_factory = FunctionFactory(config)
|
self.function_factory = FunctionFactory(config, debug=True)
|
||||||
self.randkey = jax.random.PRNGKey(seed)
|
self.randkey = jax.random.PRNGKey(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
@@ -83,27 +83,38 @@ class Pipeline:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.pop_nodes.shape[0] == self.pop_size
|
assert self.pop_nodes.shape[0] == self.pop_size
|
||||||
k, self.randkey = jax.random.split(self.randkey, 2)
|
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||||
|
|
||||||
# crossover
|
# crossover
|
||||||
# prepare elitism mask and crossover pair
|
# prepare elitism mask and crossover pair
|
||||||
elitism_mask = np.full(self.pop_size, False)
|
elitism_mask = np.full(self.pop_size, False)
|
||||||
|
|
||||||
for i, pair in enumerate(crossover_pair):
|
def aux3():
|
||||||
if not isinstance(pair, tuple): # elitism
|
nonlocal crossover_pair
|
||||||
elitism_mask[i] = True
|
for i, pair in enumerate(crossover_pair):
|
||||||
crossover_pair[i] = (pair, pair)
|
if not isinstance(pair, tuple): # elitism
|
||||||
crossover_pair = np.array(crossover_pair)
|
elitism_mask[i] = True
|
||||||
|
crossover_pair[i] = (pair, pair)
|
||||||
|
crossover_pair = np.array(crossover_pair)
|
||||||
|
return elitism_mask
|
||||||
|
|
||||||
total_keys = jax.random.split(k, self.pop_size * 2)
|
def aux4():
|
||||||
crossover_rand_keys = total_keys[:self.pop_size, :]
|
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||||
mutate_rand_keys = total_keys[self.pop_size:, :]
|
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||||
|
return crossover_rand_keys, mutate_rand_keys
|
||||||
|
|
||||||
# batch crossover
|
elitism_mask = aux3()
|
||||||
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
crossover_rand_keys, mutate_rand_keys = aux4()
|
||||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
|
||||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
def aux2():
|
||||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
# batch crossover
|
||||||
|
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
||||||
|
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||||
|
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||||
|
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||||
|
return wpn, wpc, lpn, lpc
|
||||||
|
|
||||||
|
wpn, wpc, lpn, lpc = aux2()
|
||||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||||
lpc) # new pop nodes, new pop connections
|
lpc) # new pop nodes, new pop connections
|
||||||
|
|
||||||
@@ -113,9 +124,14 @@ class Pipeline:
|
|||||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||||
|
|
||||||
# elitism don't mutate
|
# elitism don't mutate
|
||||||
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
def axu():
|
||||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
nonlocal npn, npc, m_npn, m_npc
|
||||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
||||||
|
|
||||||
|
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
||||||
|
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||||
|
|
||||||
|
axu()
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ def evaluate(forward_func: Callable) -> List[float]:
|
|||||||
return fitnesses.tolist() # returns a list
|
return fitnesses.tolist() # returns a list
|
||||||
|
|
||||||
|
|
||||||
@using_cprofile
|
# @using_cprofile
|
||||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config()
|
config = Configer.load_config()
|
||||||
pipeline = Pipeline(config, seed=114514)
|
pipeline = Pipeline(config, seed=11454)
|
||||||
pipeline.auto_run(evaluate)
|
pipeline.auto_run(evaluate)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,14 @@
|
|||||||
"problem_batch": 4,
|
"problem_batch": 4,
|
||||||
"init_maximum_nodes": 10,
|
"init_maximum_nodes": 10,
|
||||||
"expands_coe": 2,
|
"expands_coe": 2,
|
||||||
"pre_compile_times": 2
|
"pre_compile_times": 3
|
||||||
},
|
},
|
||||||
"neat": {
|
"neat": {
|
||||||
"population": {
|
"population": {
|
||||||
"fitness_criterion": "max",
|
"fitness_criterion": "max",
|
||||||
"fitness_threshold": 76,
|
"fitness_threshold": 76,
|
||||||
"generation_limit": 100,
|
"generation_limit": 100,
|
||||||
"pop_size": 1000,
|
"pop_size": 2000,
|
||||||
"reset_on_extinction": "False"
|
"reset_on_extinction": "False"
|
||||||
},
|
},
|
||||||
"gene": {
|
"gene": {
|
||||||
|
|||||||
Reference in New Issue
Block a user