Fix very slow GPU array slicing.

This commit is contained in:
wls2002
2023-05-10 17:48:07 +08:00
parent 9dfa904ce5
commit 3f37d79d06
4 changed files with 84 additions and 30 deletions

View File

@@ -94,7 +94,10 @@ class FunctionFactory:
default_weight=self.weight_mean
)
if self.debug:
return lambda *args: func(*args)
def debug_initialize(*args):
return func(*args)
return debug_initialize
else:
return func
@@ -117,8 +120,9 @@ class FunctionFactory:
# precompile other functions used in jax
key = jax.random.PRNGKey(0)
_ = jax.random.split(key, 2)
_ = jax.random.split(key, 3)
_ = jax.random.split(key, self.pop_size * 2)
_ = jax.random.split(key, self.pop_size)
print("end precompile")
@@ -171,7 +175,14 @@ class FunctionFactory:
key = ('mutate', n)
if key not in self.compiled_function:
self.compile_mutate(n)
return self.compiled_function[key]
if self.debug:
def debug_mutate(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
return debug_mutate
else:
return self.compiled_function[key]
def create_distance_with_args(self):
func = partial(
@@ -203,7 +214,17 @@ class FunctionFactory:
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
if key1 not in self.compiled_function:
self.compile_distance(n)
return self.compiled_function[key1], self.compiled_function[key2]
if self.debug:
def debug_o2o_distance(*args):
return self.compiled_function[key1](*args).block_until_ready()
def debug_o2m_distance(*args):
return self.compiled_function[key2](*args).block_until_ready()
return debug_o2o_distance, debug_o2m_distance
else:
return self.compiled_function[key1], self.compiled_function[key2]
def create_crossover_with_args(self):
self.crossover_with_args = crossover
@@ -223,7 +244,15 @@ class FunctionFactory:
key = ('crossover', n)
if key not in self.compiled_function:
self.compile_crossover(n)
return self.compiled_function[key]
if self.debug:
def debug_crossover(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
return debug_crossover
else:
return self.compiled_function[key]
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort
@@ -303,7 +332,13 @@ class FunctionFactory:
key = ('pop_batch_forward', n)
if key not in self.compiled_function:
self.compile_pop_batch_forward(n)
return self.compiled_function[key]
if self.debug:
def debug_pop_batch_forward(*args):
return self.compiled_function[key](*args).block_until_ready()
return debug_pop_batch_forward
else:
return self.compiled_function[key]
def ask(self, pop_nodes, pop_connections):
n = pop_nodes.shape[1]
@@ -312,7 +347,10 @@ class FunctionFactory:
forward_func = self.create_pop_batch_forward(n)
return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
def debug_forward(inputs):
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
return debug_forward
# return partial(
# forward_func,