prepare for experiment
This commit is contained in:
@@ -19,7 +19,7 @@ class FunctionFactory:
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.precompile_times = config.basic.pre_compile_times
|
||||
self.compiled_function = {}
|
||||
self.time_cost = {}
|
||||
self.compile_time = 0
|
||||
|
||||
self.load_config_vals(config)
|
||||
|
||||
@@ -150,6 +150,8 @@ class FunctionFactory:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_update_speciate(self, N, C, S):
|
||||
s = time.time()
|
||||
|
||||
func = self.update_speciate_with_args
|
||||
randkey_lower = np.zeros((2,), dtype=np.uint32)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
|
||||
@@ -177,16 +179,22 @@ class FunctionFactory:
|
||||
).compile()
|
||||
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort_with_args(self):
|
||||
self.topological_sort_with_args = topological_sort
|
||||
|
||||
def compile_topological_sort(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.topological_sort_with_args
|
||||
nodes_lower = np.zeros((n, 5))
|
||||
connections_lower = np.zeros((2, n, n))
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort(self, n):
|
||||
key = ('topological_sort', n)
|
||||
if key not in self.compiled_function:
|
||||
@@ -194,6 +202,8 @@ class FunctionFactory:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_topological_sort_batch(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.topological_sort_with_args
|
||||
func = vmap(func)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
@@ -201,6 +211,8 @@ class FunctionFactory:
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort_batch', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort_batch(self, n):
|
||||
key = ('topological_sort_batch', n)
|
||||
if key not in self.compiled_function:
|
||||
@@ -215,32 +227,10 @@ class FunctionFactory:
|
||||
)
|
||||
self.single_forward_with_args = func
|
||||
|
||||
def compile_single_forward(self, n):
|
||||
"""
|
||||
single input for a genome
|
||||
:param n:
|
||||
:return:
|
||||
"""
|
||||
func = self.single_forward_with_args
|
||||
inputs_lower = np.zeros((self.num_inputs,))
|
||||
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
|
||||
nodes_lower = np.zeros((n, 5))
|
||||
connections_lower = np.zeros((2, n, n))
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('single_forward', n)] = func
|
||||
|
||||
def compile_pop_forward(self, n):
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(None, 0, 0, 0))
|
||||
|
||||
inputs_lower = np.zeros((self.num_inputs,))
|
||||
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('pop_forward', n)] = func
|
||||
|
||||
def compile_batch_forward(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(0, None, None, None))
|
||||
|
||||
@@ -251,19 +241,19 @@ class FunctionFactory:
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('batch_forward', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_batch_forward(self, n):
|
||||
key = ('batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_forward(n)
|
||||
if self.debug:
|
||||
def debug_batch_forward(*args):
|
||||
return self.compiled_function[key](*args).block_until_ready()
|
||||
|
||||
return debug_batch_forward
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_pop_batch_forward(self, n):
|
||||
|
||||
s = time.time()
|
||||
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
||||
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
|
||||
@@ -276,25 +266,24 @@ class FunctionFactory:
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('pop_batch_forward', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_pop_batch_forward(self, n):
|
||||
key = ('pop_batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_pop_batch_forward(n)
|
||||
if self.debug:
|
||||
def debug_pop_batch_forward(*args):
|
||||
return self.compiled_function[key](*args).block_until_ready()
|
||||
|
||||
return debug_pop_batch_forward
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
return self.compiled_function[key]
|
||||
|
||||
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
|
||||
n, c = pop_nodes.shape[1], pop_cons.shape[1]
|
||||
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
|
||||
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
|
||||
ts = self.create_topological_sort_batch(n)
|
||||
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
||||
|
||||
# for connections with enabled is false, set weight to 0)
|
||||
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
||||
# print(pop_cal_seqs)
|
||||
forward_func = self.create_pop_batch_forward(n)
|
||||
|
||||
def debug_forward(inputs):
|
||||
@@ -314,6 +303,9 @@ class FunctionFactory:
|
||||
return debug_forward
|
||||
|
||||
def compile_batch_unflatten_connections(self, n, c):
|
||||
|
||||
s = time.time()
|
||||
|
||||
func = unflatten_connections
|
||||
func = vmap(func)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
@@ -321,14 +313,11 @@ class FunctionFactory:
|
||||
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
|
||||
self.compiled_function[('batch_unflatten_connections', n, c)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_batch_unflatten_connections(self, n, c):
|
||||
key = ('batch_unflatten_connections', n, c)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_unflatten_connections(n, c)
|
||||
if self.debug:
|
||||
def debug_batch_unflatten_connections(*args):
|
||||
return self.compiled_function[key](*args).block_until_ready()
|
||||
|
||||
return debug_batch_unflatten_connections
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
return self.compiled_function[key]
|
||||
|
||||
Reference in New Issue
Block a user