""" Lowers, compiles, and creates functions used in the NEAT pipeline. """ from functools import partial import time import numpy as np from jax import jit, vmap from .genome import act_name2key, agg_name2key, initialize_genomes from .genome import topological_sort, forward_single, unflatten_connections from .population import create_next_generation_then_speciate class FunctionFactory: def __init__(self, config): self.config = config self.expand_coe = config.basic.expands_coe self.precompile_times = config.basic.pre_compile_times self.compiled_function = {} self.compile_time = 0 self.load_config_vals(config) self.create_topological_sort_with_args() self.create_single_forward_with_args() self.create_update_speciate_with_args() def load_config_vals(self, config): self.compatibility_threshold = self.config.neat.species.compatibility_threshold self.problem_batch = config.basic.problem_batch self.pop_size = config.neat.population.pop_size self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient self.num_inputs = config.basic.num_inputs self.num_outputs = config.basic.num_outputs self.input_idx = np.arange(self.num_inputs) self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs) bias = config.neat.gene.bias self.bias_mean = bias.init_mean self.bias_std = bias.init_stdev self.bias_mutate_strength = bias.mutate_power self.bias_mutate_rate = bias.mutate_rate self.bias_replace_rate = bias.replace_rate response = config.neat.gene.response self.response_mean = response.init_mean self.response_std = response.init_stdev self.response_mutate_strength = response.mutate_power self.response_mutate_rate = response.mutate_rate self.response_replace_rate = response.replace_rate weight = config.neat.gene.weight self.weight_mean = weight.init_mean self.weight_std = weight.init_stdev self.weight_mutate_strength = weight.mutate_power self.weight_mutate_rate = weight.mutate_rate self.weight_replace_rate = weight.replace_rate activation = config.neat.gene.activation self.act_default = act_name2key[activation.default] self.act_list = np.array([act_name2key[name] for name in activation.options]) self.act_replace_rate = activation.mutate_rate aggregation = config.neat.gene.aggregation self.agg_default = agg_name2key[aggregation.default] self.agg_list = np.array([agg_name2key[name] for name in aggregation.options]) self.agg_replace_rate = aggregation.mutate_rate enabled = config.neat.gene.enabled self.enabled_reverse_rate = enabled.mutate_rate genome = config.neat.genome self.add_node_rate = genome.node_add_prob self.delete_node_rate = genome.node_delete_prob self.add_connection_rate = genome.conn_add_prob self.delete_connection_rate = genome.conn_delete_prob self.single_structure_mutate = genome.single_structural_mutation def create_initialize(self, N, C): func = partial( initialize_genomes, pop_size=self.pop_size, N=N, C=C, num_inputs=self.num_inputs, num_outputs=self.num_outputs, default_bias=self.bias_mean, default_response=self.response_mean, default_act=self.act_default, default_agg=self.agg_default, default_weight=self.weight_mean ) return func def create_update_speciate_with_args(self): species_kwargs = { "disjoint_coe": self.disjoint_coe, "compatibility_coe": self.compatibility_coe, "compatibility_threshold": self.compatibility_threshold } mutate_kwargs = { "input_idx": self.input_idx, "output_idx": self.output_idx, "bias_mean": self.bias_mean, "bias_std": self.bias_std, "bias_mutate_strength": self.bias_mutate_strength, "bias_mutate_rate": self.bias_mutate_rate, "bias_replace_rate": self.bias_replace_rate, "response_mean": self.response_mean, "response_std": self.response_std, "response_mutate_strength": self.response_mutate_strength, "response_mutate_rate": self.response_mutate_rate, "response_replace_rate": self.response_replace_rate, "weight_mean": self.weight_mean, "weight_std": self.weight_std, "weight_mutate_strength": self.weight_mutate_strength, "weight_mutate_rate": self.weight_mutate_rate, "weight_replace_rate": self.weight_replace_rate, "act_default": self.act_default, "act_list": self.act_list, "act_replace_rate": self.act_replace_rate, "agg_default": self.agg_default, "agg_list": self.agg_list, "agg_replace_rate": self.agg_replace_rate, "enabled_reverse_rate": self.enabled_reverse_rate, "add_node_rate": self.add_node_rate, "delete_node_rate": self.delete_node_rate, "add_connection_rate": self.add_connection_rate, "delete_connection_rate": self.delete_connection_rate, } self.update_speciate_with_args = partial( create_next_generation_then_speciate, species_kwargs=species_kwargs, mutate_kwargs=mutate_kwargs ) def create_update_speciate(self, N, C, S): key = ("update_speciate", N, C, S) if key not in self.compiled_function: self.compile_update_speciate(N, C, S) return self.compiled_function[key] def compile_update_speciate(self, N, C, S): s = time.time() func = self.update_speciate_with_args randkey_lower = np.zeros((2,), dtype=np.uint32) pop_nodes_lower = np.zeros((self.pop_size, N, 5)) pop_cons_lower = np.zeros((self.pop_size, C, 4)) winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32) loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32) elite_mask_lower = np.zeros((self.pop_size,), dtype=bool) new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32) pre_spe_center_nodes_lower = np.zeros((S, N, 5)) pre_spe_center_cons_lower = np.zeros((S, C, 4)) species_keys = np.zeros((S,), dtype=np.int32) new_species_keys_lower = 0 compiled_func = jit(func).lower( randkey_lower, pop_nodes_lower, pop_cons_lower, winner_part_lower, loser_part_lower, elite_mask_lower, new_node_keys_start_lower, pre_spe_center_nodes_lower, pre_spe_center_cons_lower, species_keys, new_species_keys_lower, ).compile() self.compiled_function[("update_speciate", N, C, S)] = compiled_func self.compile_time += time.time() - s def create_topological_sort_with_args(self): self.topological_sort_with_args = topological_sort def compile_topological_sort(self, n): s = time.time() func = self.topological_sort_with_args nodes_lower = np.zeros((n, 5)) connections_lower = np.zeros((2, n, n)) func = jit(func).lower(nodes_lower, connections_lower).compile() self.compiled_function[('topological_sort', n)] = func self.compile_time += time.time() - s def create_topological_sort(self, n): key = ('topological_sort', n) if key not in self.compiled_function: self.compile_topological_sort(n) return self.compiled_function[key] def compile_topological_sort_batch(self, n): s = time.time() func = self.topological_sort_with_args func = vmap(func) nodes_lower = np.zeros((self.pop_size, n, 5)) connections_lower = np.zeros((self.pop_size, 2, n, n)) func = jit(func).lower(nodes_lower, connections_lower).compile() self.compiled_function[('topological_sort_batch', n)] = func self.compile_time += time.time() - s def create_topological_sort_batch(self, n): key = ('topological_sort_batch', n) if key not in self.compiled_function: self.compile_topological_sort_batch(n) return self.compiled_function[key] def create_single_forward_with_args(self): func = partial( forward_single, input_idx=self.input_idx, output_idx=self.output_idx ) self.single_forward_with_args = func def compile_batch_forward(self, n): s = time.time() func = self.single_forward_with_args func = vmap(func, in_axes=(0, None, None, None)) inputs_lower = np.zeros((self.problem_batch, self.num_inputs)) cal_seqs_lower = np.zeros((n,), dtype=np.int32) nodes_lower = np.zeros((n, 5)) connections_lower = np.zeros((2, n, n)) func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() self.compiled_function[('batch_forward', n)] = func self.compile_time += time.time() - s def create_batch_forward(self, n): key = ('batch_forward', n) if key not in self.compiled_function: self.compile_batch_forward(n) return self.compiled_function[key] def compile_pop_batch_forward(self, n): s = time.time() func = self.single_forward_with_args func = vmap(func, in_axes=(0, None, None, None)) # batch_forward func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward inputs_lower = np.zeros((self.problem_batch, self.num_inputs)) cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32) nodes_lower = np.zeros((self.pop_size, n, 5)) connections_lower = np.zeros((self.pop_size, 2, n, n)) func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() self.compiled_function[('pop_batch_forward', n)] = func self.compile_time += time.time() - s def create_pop_batch_forward(self, n): key = ('pop_batch_forward', n) if key not in self.compiled_function: self.compile_pop_batch_forward(n) return self.compiled_function[key] def ask_pop_batch_forward(self, pop_nodes, pop_cons): n, c = pop_nodes.shape[1], pop_cons.shape[1] batch_unflatten_func = self.create_batch_unflatten_connections(n, c) pop_cons = batch_unflatten_func(pop_nodes, pop_cons) ts = self.create_topological_sort_batch(n) # for connections with enabled is false, set weight to 0) pop_cal_seqs = ts(pop_nodes, pop_cons) # print(pop_cal_seqs) forward_func = self.create_pop_batch_forward(n) def debug_forward(inputs): return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons) return debug_forward def ask_batch_forward(self, nodes, connections): n = nodes.shape[0] ts = self.create_topological_sort(n) cal_seqs = ts(nodes, connections) forward_func = self.create_batch_forward(n) def debug_forward(inputs): return forward_func(inputs, cal_seqs, nodes, connections) return debug_forward def compile_batch_unflatten_connections(self, n, c): s = time.time() func = unflatten_connections func = vmap(func) pop_nodes_lower = np.zeros((self.pop_size, n, 5)) pop_connections_lower = np.zeros((self.pop_size, c, 4)) func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile() self.compiled_function[('batch_unflatten_connections', n, c)] = func self.compile_time += time.time() - s def create_batch_unflatten_connections(self, n, c): key = ('batch_unflatten_connections', n, c) if key not in self.compiled_function: self.compile_batch_unflatten_connections(n, c) return self.compiled_function[key]