add function to put **all** compilation at the beginning of the execution.

This commit is contained in:
wls2002
2023-05-09 02:55:47 +08:00
parent 1f2327bbd6
commit 0fdc856f2d
6 changed files with 183 additions and 75 deletions

View File

@@ -7,6 +7,7 @@ import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome import topological_sort, forward_single
class FunctionFactory:
@@ -24,6 +25,8 @@ class FunctionFactory:
pass
def load_config_vals(self, config):
self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size
self.init_N = config.basic.init_maximum_nodes
@@ -98,12 +101,17 @@ class FunctionFactory:
self.create_mutate_with_args()
self.create_distance_with_args()
self.create_crossover_with_args()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
n = self.init_N
print("start precompile")
for _ in range(self.precompile_times):
self.compile_mutate(n)
self.compile_distance(n)
self.compile_crossover(n)
self.compile_topological_sort(n)
self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n)
print("end precompile")
@@ -209,3 +217,99 @@ class FunctionFactory:
if key not in self.compiled_function:
self.compile_crossover(n)
return self.compiled_function[key]
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort
def compile_topological_sort(self, n):
func = self.topological_sort_with_args
func = vmap(func)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort', n)] = func
def create_topological_sort(self, n):
key = ('topological_sort', n)
if key not in self.compiled_function:
self.compile_topological_sort(n)
return self.compiled_function[key]
def create_single_forward_with_args(self):
func = partial(
forward_single,
input_idx=self.input_idx,
output_idx=self.output_idx
)
self.single_forward_with_args = func
def compile_single_forward(self, n):
"""
single input for a genome
:param n:
:return:
"""
func = self.single_forward_with_args
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('single_forward', n)] = func
def compile_pop_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(None, 0, 0, 0))
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_forward', n)] = func
def compile_batch_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None))
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('batch_forward', n)] = func
def compile_pop_batch_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_batch_forward', n)] = func
def create_pop_batch_forward(self, n):
key = ('pop_batch_forward', n)
if key not in self.compiled_function:
self.compile_pop_batch_forward(n)
return self.compiled_function[key]
def ask(self, pop_nodes, pop_connections):
n = pop_nodes.shape[1]
ts = self.create_topological_sort(n)
pop_cal_seqs = ts(pop_nodes, pop_connections)
forward_func = self.create_pop_batch_forward(n)
return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
# return partial(
# forward_func,
# cal_seqs=pop_cal_seqs,
# nodes=pop_nodes,
# connections=pop_connections
# )