refactor genome.py use (C, 4) to replace (2, N, N) to represent connections

faster, faster and faster!
This commit is contained in:
wls2002
2023-05-12 00:57:55 +08:00
parent e5fc1167d9
commit 47b1a1dbb2
16 changed files with 363 additions and 419 deletions

View File

@@ -8,7 +8,7 @@ import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome import topological_sort, forward_single
from .genome import topological_sort, forward_single, unflatten_connections
class FunctionFactory:
@@ -17,19 +17,18 @@ class FunctionFactory:
self.debug = debug
self.init_N = config.basic.init_maximum_nodes
self.init_C = config.basic.init_maximum_connections
self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {}
self.load_config_vals(config)
self.precompile()
pass
def load_config_vals(self, config):
self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size
self.init_N = config.basic.init_maximum_nodes
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
@@ -85,6 +84,7 @@ class FunctionFactory:
initialize_genomes,
pop_size=self.pop_size,
N=self.init_N,
C=self.init_C,
num_inputs=self.num_inputs,
num_outputs=self.num_outputs,
default_bias=self.bias_mean,
@@ -107,24 +107,24 @@ class FunctionFactory:
self.create_crossover_with_args()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
n = self.init_N
print("start precompile")
for _ in range(self.precompile_times):
self.compile_mutate(n)
self.compile_distance(n)
self.compile_crossover(n)
self.compile_topological_sort_batch(n)
self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n)
# precompile other functions used in jax
key = jax.random.PRNGKey(0)
_ = jax.random.split(key, 3)
_ = jax.random.split(key, self.pop_size * 2)
_ = jax.random.split(key, self.pop_size)
print("end precompile")
#
# n, c = self.init_N, self.init_C
# print("start precompile")
# for _ in range(self.precompile_times):
# self.compile_mutate(n)
# self.compile_distance(n)
# self.compile_crossover(n)
# self.compile_topological_sort_batch(n)
# self.compile_pop_batch_forward(n)
# n = int(self.expand_coe * n)
#
# # precompile other functions used in jax
# key = jax.random.PRNGKey(0)
# _ = jax.random.split(key, 3)
# _ = jax.random.split(key, self.pop_size * 2)
# _ = jax.random.split(key, self.pop_size)
#
# print("end precompile")
def create_mutate_with_args(self):
func = partial(
@@ -161,20 +161,20 @@ class FunctionFactory:
)
self.mutate_with_args = func
def compile_mutate(self, n):
def compile_mutate(self, n, c):
func = self.mutate_with_args
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
connections_lower = np.zeros((self.pop_size, c, 4))
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
self.compiled_function[('mutate', n)] = batched_mutate_func
self.compiled_function[('mutate', n, c)] = batched_mutate_func
def create_mutate(self, n):
key = ('mutate', n)
def create_mutate(self, n, c):
key = ('mutate', n, c)
if key not in self.compiled_function:
self.compile_mutate(n)
self.compile_mutate(n, c)
if self.debug:
def debug_mutate(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
@@ -192,28 +192,28 @@ class FunctionFactory:
)
self.distance_with_args = func
def compile_distance(self, n):
def compile_distance(self, n, c):
func = self.distance_with_args
o2o_nodes1_lower = np.zeros((n, 5))
o2o_connections1_lower = np.zeros((2, n, n))
o2o_connections1_lower = np.zeros((c, 4))
o2o_nodes2_lower = np.zeros((n, 5))
o2o_connections2_lower = np.zeros((2, n, n))
o2o_connections2_lower = np.zeros((c, 4))
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2o_nodes2_lower, o2o_connections2_lower).compile()
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n))
o2m_connections2_lower = np.zeros((self.pop_size, c, 4))
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2m_nodes2_lower,
o2m_connections2_lower).compile()
self.compiled_function[('o2o_distance', n)] = o2o_distance
self.compiled_function[('o2m_distance', n)] = o2m_distance
self.compiled_function[('o2o_distance', n, c)] = o2o_distance
self.compiled_function[('o2m_distance', n, c)] = o2m_distance
def create_distance(self, n):
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
def create_distance(self, n, c):
key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
if key1 not in self.compiled_function:
self.compile_distance(n)
self.compile_distance(n, c)
if self.debug:
def debug_o2o_distance(*args):
@@ -229,21 +229,21 @@ class FunctionFactory:
def create_crossover_with_args(self):
self.crossover_with_args = crossover
def compile_crossover(self, n):
def compile_crossover(self, n, c):
func = self.crossover_with_args
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes1_lower = np.zeros((self.pop_size, n, 5))
connections1_lower = np.zeros((self.pop_size, 2, n, n))
connections1_lower = np.zeros((self.pop_size, c, 4))
nodes2_lower = np.zeros((self.pop_size, n, 5))
connections2_lower = np.zeros((self.pop_size, 2, n, n))
connections2_lower = np.zeros((self.pop_size, c, 4))
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
self.compiled_function[('crossover', n)] = func
self.compiled_function[('crossover', n, c)] = func
def create_crossover(self, n):
key = ('crossover', n)
def create_crossover(self, n, c):
key = ('crossover', n, c)
if key not in self.compiled_function:
self.compile_crossover(n)
self.compile_crossover(n, c)
if self.debug:
def debug_crossover(*args):
@@ -365,15 +365,17 @@ class FunctionFactory:
else:
return self.compiled_function[key]
def ask_pop_batch_forward(self, pop_nodes, pop_connections):
n = pop_nodes.shape[1]
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
n, c = pop_nodes.shape[1], pop_cons.shape[1]
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
ts = self.create_topological_sort_batch(n)
pop_cal_seqs = ts(pop_nodes, pop_connections)
pop_cal_seqs = ts(pop_nodes, pop_cons)
forward_func = self.create_pop_batch_forward(n)
def debug_forward(inputs):
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons)
return debug_forward
@@ -387,3 +389,23 @@ class FunctionFactory:
return forward_func(inputs, cal_seqs, nodes, connections)
return debug_forward
def compile_batch_unflatten_connections(self, n, c):
func = unflatten_connections
func = vmap(func)
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
pop_connections_lower = np.zeros((self.pop_size, c, 4))
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
self.compiled_function[('batch_unflatten_connections', n, c)] = func
def create_batch_unflatten_connections(self, n, c):
key = ('batch_unflatten_connections', n, c)
if key not in self.compiled_function:
self.compile_batch_unflatten_connections(n, c)
if self.debug:
def debug_batch_unflatten_connections(*args):
return self.compiled_function[key](*args).block_until_ready()
return debug_batch_unflatten_connections
else:
return self.compiled_function[key]