This commit is contained in:
wls2002
2023-05-13 20:58:03 +08:00
parent 90a9cc322d
commit 72c9d4167a
10 changed files with 372 additions and 529 deletions

View File

@@ -2,30 +2,34 @@
Lowers, compiles, and creates functions used in the NEAT pipeline.
"""
from functools import partial
import time
import jax.random
import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome import act_name2key, agg_name2key, initialize_genomes
from .genome import topological_sort, forward_single, unflatten_connections
from .population import create_next_generation_then_speciate
class FunctionFactory:
def __init__(self, config, debug=False):
def __init__(self, config):
self.config = config
self.debug = debug
self.init_N = config.basic.init_maximum_nodes
self.init_C = config.basic.init_maximum_connections
self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {}
self.time_cost = {}
self.load_config_vals(config)
self.precompile()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
self.create_update_speciate_with_args()
def load_config_vals(self, config):
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size
@@ -79,12 +83,12 @@ class FunctionFactory:
self.delete_connection_rate = genome.conn_delete_prob
self.single_structure_mutate = genome.single_structural_mutation
def create_initialize(self):
def create_initialize(self, N, C):
func = partial(
initialize_genomes,
pop_size=self.pop_size,
N=self.init_N,
C=self.init_C,
N=N,
C=C,
num_inputs=self.num_inputs,
num_outputs=self.num_outputs,
default_bias=self.bias_mean,
@@ -93,166 +97,85 @@ class FunctionFactory:
default_agg=self.agg_default,
default_weight=self.weight_mean
)
if self.debug:
def debug_initialize(*args):
return func(*args)
return func
return debug_initialize
else:
return func
def create_update_speciate_with_args(self):
species_kwargs = {
"disjoint_coe": self.disjoint_coe,
"compatibility_coe": self.compatibility_coe,
"compatibility_threshold": self.compatibility_threshold
}
def precompile(self):
self.create_mutate_with_args()
self.create_distance_with_args()
self.create_crossover_with_args()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
#
# n, c = self.init_N, self.init_C
# print("start precompile")
# for _ in range(self.precompile_times):
# self.compile_mutate(n)
# self.compile_distance(n)
# self.compile_crossover(n)
# self.compile_topological_sort_batch(n)
# self.compile_pop_batch_forward(n)
# n = int(self.expand_coe * n)
#
# # precompile other functions used in jax
# key = jax.random.PRNGKey(0)
# _ = jax.random.split(key, 3)
# _ = jax.random.split(key, self.pop_size * 2)
# _ = jax.random.split(key, self.pop_size)
#
# print("end precompile")
mutate_kwargs = {
"input_idx": self.input_idx,
"output_idx": self.output_idx,
"bias_mean": self.bias_mean,
"bias_std": self.bias_std,
"bias_mutate_strength": self.bias_mutate_strength,
"bias_mutate_rate": self.bias_mutate_rate,
"bias_replace_rate": self.bias_replace_rate,
"response_mean": self.response_mean,
"response_std": self.response_std,
"response_mutate_strength": self.response_mutate_strength,
"response_mutate_rate": self.response_mutate_rate,
"response_replace_rate": self.response_replace_rate,
"weight_mean": self.weight_mean,
"weight_std": self.weight_std,
"weight_mutate_strength": self.weight_mutate_strength,
"weight_mutate_rate": self.weight_mutate_rate,
"weight_replace_rate": self.weight_replace_rate,
"act_default": self.act_default,
"act_list": self.act_list,
"act_replace_rate": self.act_replace_rate,
"agg_default": self.agg_default,
"agg_list": self.agg_list,
"agg_replace_rate": self.agg_replace_rate,
"enabled_reverse_rate": self.enabled_reverse_rate,
"add_node_rate": self.add_node_rate,
"delete_node_rate": self.delete_node_rate,
"add_connection_rate": self.add_connection_rate,
"delete_connection_rate": self.delete_connection_rate,
}
def create_mutate_with_args(self):
func = partial(
mutate,
input_idx=self.input_idx,
output_idx=self.output_idx,
bias_mean=self.bias_mean,
bias_std=self.bias_std,
bias_mutate_strength=self.bias_mutate_strength,
bias_mutate_rate=self.bias_mutate_rate,
bias_replace_rate=self.bias_replace_rate,
response_mean=self.response_mean,
response_std=self.response_std,
response_mutate_strength=self.response_mutate_strength,
response_mutate_rate=self.response_mutate_rate,
response_replace_rate=self.response_replace_rate,
weight_mean=self.weight_mean,
weight_std=self.weight_std,
weight_mutate_strength=self.weight_mutate_strength,
weight_mutate_rate=self.weight_mutate_rate,
weight_replace_rate=self.weight_replace_rate,
act_default=self.act_default,
act_list=self.act_list,
act_replace_rate=self.act_replace_rate,
agg_default=self.agg_default,
agg_list=self.agg_list,
agg_replace_rate=self.agg_replace_rate,
enabled_reverse_rate=self.enabled_reverse_rate,
add_node_rate=self.add_node_rate,
delete_node_rate=self.delete_node_rate,
add_connection_rate=self.add_connection_rate,
delete_connection_rate=self.delete_connection_rate,
single_structure_mutate=self.single_structure_mutate
self.update_speciate_with_args = partial(
create_next_generation_then_speciate,
species_kwargs=species_kwargs,
mutate_kwargs=mutate_kwargs
)
self.mutate_with_args = func
def compile_mutate(self, n, c):
func = self.mutate_with_args
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, c, 4))
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
self.compiled_function[('mutate', n, c)] = batched_mutate_func
def create_mutate(self, n, c):
key = ('mutate', n, c)
def create_update_speciate(self, N, C, S):
key = ("update_speciate", N, C, S)
if key not in self.compiled_function:
self.compile_mutate(n, c)
if self.debug:
def debug_mutate(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
self.compile_update_speciate(N, C, S)
return self.compiled_function[key]
return debug_mutate
else:
return self.compiled_function[key]
def create_distance_with_args(self):
func = partial(
distance,
disjoint_coe=self.disjoint_coe,
compatibility_coe=self.compatibility_coe
)
self.distance_with_args = func
def compile_distance(self, n, c):
func = self.distance_with_args
o2o_nodes1_lower = np.zeros((n, 5))
o2o_connections1_lower = np.zeros((c, 4))
o2o_nodes2_lower = np.zeros((n, 5))
o2o_connections2_lower = np.zeros((c, 4))
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2o_nodes2_lower, o2o_connections2_lower).compile()
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
o2m_connections2_lower = np.zeros((self.pop_size, c, 4))
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2m_nodes2_lower,
o2m_connections2_lower).compile()
self.compiled_function[('o2o_distance', n, c)] = o2o_distance
self.compiled_function[('o2m_distance', n, c)] = o2m_distance
def create_distance(self, n, c):
key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
if key1 not in self.compiled_function:
self.compile_distance(n, c)
if self.debug:
def debug_o2o_distance(*args):
return self.compiled_function[key1](*args).block_until_ready()
def debug_o2m_distance(*args):
return self.compiled_function[key2](*args).block_until_ready()
return debug_o2o_distance, debug_o2m_distance
else:
return self.compiled_function[key1], self.compiled_function[key2]
def create_crossover_with_args(self):
self.crossover_with_args = crossover
def compile_crossover(self, n, c):
func = self.crossover_with_args
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes1_lower = np.zeros((self.pop_size, n, 5))
connections1_lower = np.zeros((self.pop_size, c, 4))
nodes2_lower = np.zeros((self.pop_size, n, 5))
connections2_lower = np.zeros((self.pop_size, c, 4))
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
self.compiled_function[('crossover', n, c)] = func
def create_crossover(self, n, c):
key = ('crossover', n, c)
if key not in self.compiled_function:
self.compile_crossover(n, c)
if self.debug:
def debug_crossover(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
return debug_crossover
else:
return self.compiled_function[key]
def compile_update_speciate(self, N, C, S):
func = self.update_speciate_with_args
randkey_lower = np.zeros((2,), dtype=np.uint32)
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
pop_cons_lower = np.zeros((self.pop_size, C, 4))
winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
elite_mask_lower = np.zeros((self.pop_size,), dtype=bool)
new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32)
pre_spe_center_nodes_lower = np.zeros((S, N, 5))
pre_spe_center_cons_lower = np.zeros((S, C, 4))
species_keys = np.zeros((S,), dtype=np.int32)
new_species_keys_lower = 0
compiled_func = jit(func).lower(
randkey_lower,
pop_nodes_lower,
pop_cons_lower,
winner_part_lower,
loser_part_lower,
elite_mask_lower,
new_node_keys_start_lower,
pre_spe_center_nodes_lower,
pre_spe_center_cons_lower,
species_keys,
new_species_keys_lower,
).compile()
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort