update to test in servers

This commit is contained in:
wls2002
2023-05-10 22:33:51 +08:00
parent ce35b01896
commit b271a56827
9 changed files with 112 additions and 34 deletions

View File

@@ -114,7 +114,7 @@ class FunctionFactory:
self.compile_mutate(n)
self.compile_distance(n)
self.compile_crossover(n)
self.compile_topological_sort(n)
self.compile_topological_sort_batch(n)
self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n)
@@ -259,9 +259,8 @@ class FunctionFactory:
def compile_topological_sort(self, n):
func = self.topological_sort_with_args
func = vmap(func)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort', n)] = func
@@ -271,6 +270,20 @@ class FunctionFactory:
self.compile_topological_sort(n)
return self.compiled_function[key]
def compile_topological_sort_batch(self, n):
func = self.topological_sort_with_args
func = vmap(func)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort_batch', n)] = func
def create_topological_sort_batch(self, n):
key = ('topological_sort_batch', n)
if key not in self.compiled_function:
self.compile_topological_sort_batch(n)
return self.compiled_function[key]
def create_single_forward_with_args(self):
func = partial(
forward_single,
@@ -315,6 +328,18 @@ class FunctionFactory:
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('batch_forward', n)] = func
def create_batch_forward(self, n):
key = ('batch_forward', n)
if key not in self.compiled_function:
self.compile_batch_forward(n)
if self.debug:
def debug_batch_forward(*args):
return self.compiled_function[key](*args).block_until_ready()
return debug_batch_forward
else:
return self.compiled_function[key]
def compile_pop_batch_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
@@ -340,9 +365,9 @@ class FunctionFactory:
else:
return self.compiled_function[key]
def ask(self, pop_nodes, pop_connections):
def ask_pop_batch_forward(self, pop_nodes, pop_connections):
n = pop_nodes.shape[1]
ts = self.create_topological_sort(n)
ts = self.create_topological_sort_batch(n)
pop_cal_seqs = ts(pop_nodes, pop_connections)
forward_func = self.create_pop_batch_forward(n)
@@ -352,9 +377,13 @@ class FunctionFactory:
return debug_forward
# return partial(
# forward_func,
# cal_seqs=pop_cal_seqs,
# nodes=pop_nodes,
# connections=pop_connections
# )
def ask_batch_forward(self, nodes, connections):
n = nodes.shape[0]
ts = self.create_topological_sort(n)
cal_seqs = ts(nodes, connections)
forward_func = self.create_batch_forward(n)
def debug_forward(inputs):
return forward_func(inputs, cal_seqs, nodes, connections)
return debug_forward