GPU slicing is very slow; fix this problem.
This commit is contained in:
@@ -94,7 +94,10 @@ class FunctionFactory:
|
||||
default_weight=self.weight_mean
|
||||
)
|
||||
if self.debug:
|
||||
return lambda *args: func(*args)
|
||||
def debug_initialize(*args):
|
||||
return func(*args)
|
||||
|
||||
return debug_initialize
|
||||
else:
|
||||
return func
|
||||
|
||||
@@ -117,8 +120,9 @@ class FunctionFactory:
|
||||
|
||||
# precompile other functions used in jax
|
||||
key = jax.random.PRNGKey(0)
|
||||
_ = jax.random.split(key, 2)
|
||||
_ = jax.random.split(key, 3)
|
||||
_ = jax.random.split(key, self.pop_size * 2)
|
||||
_ = jax.random.split(key, self.pop_size)
|
||||
|
||||
print("end precompile")
|
||||
|
||||
@@ -171,7 +175,14 @@ class FunctionFactory:
|
||||
key = ('mutate', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_mutate(n)
|
||||
return self.compiled_function[key]
|
||||
if self.debug:
|
||||
def debug_mutate(*args):
|
||||
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||
return res_nodes.block_until_ready(), res_connections.block_until_ready()
|
||||
|
||||
return debug_mutate
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def create_distance_with_args(self):
|
||||
func = partial(
|
||||
@@ -203,7 +214,17 @@ class FunctionFactory:
|
||||
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
|
||||
if key1 not in self.compiled_function:
|
||||
self.compile_distance(n)
|
||||
return self.compiled_function[key1], self.compiled_function[key2]
|
||||
if self.debug:
|
||||
|
||||
def debug_o2o_distance(*args):
|
||||
return self.compiled_function[key1](*args).block_until_ready()
|
||||
|
||||
def debug_o2m_distance(*args):
|
||||
return self.compiled_function[key2](*args).block_until_ready()
|
||||
|
||||
return debug_o2o_distance, debug_o2m_distance
|
||||
else:
|
||||
return self.compiled_function[key1], self.compiled_function[key2]
|
||||
|
||||
def create_crossover_with_args(self):
|
||||
self.crossover_with_args = crossover
|
||||
@@ -223,7 +244,15 @@ class FunctionFactory:
|
||||
key = ('crossover', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_crossover(n)
|
||||
return self.compiled_function[key]
|
||||
if self.debug:
|
||||
|
||||
def debug_crossover(*args):
|
||||
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||
return res_nodes.block_until_ready(), res_connections.block_until_ready()
|
||||
|
||||
return debug_crossover
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def create_topological_sort_with_args(self):
|
||||
self.topological_sort_with_args = topological_sort
|
||||
@@ -303,7 +332,13 @@ class FunctionFactory:
|
||||
key = ('pop_batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_pop_batch_forward(n)
|
||||
return self.compiled_function[key]
|
||||
if self.debug:
|
||||
def debug_pop_batch_forward(*args):
|
||||
return self.compiled_function[key](*args).block_until_ready()
|
||||
|
||||
return debug_pop_batch_forward
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def ask(self, pop_nodes, pop_connections):
|
||||
n = pop_nodes.shape[1]
|
||||
@@ -312,7 +347,10 @@ class FunctionFactory:
|
||||
|
||||
forward_func = self.create_pop_batch_forward(n)
|
||||
|
||||
return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
def debug_forward(inputs):
|
||||
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
|
||||
return debug_forward
|
||||
|
||||
# return partial(
|
||||
# forward_func,
|
||||
|
||||
Reference in New Issue
Block a user