"""
Lowers, compiles, and creates functions used in the NEAT pipeline.
"""
from functools import partial

import numpy as np
from jax import jit, vmap

from .genome import act_name2key, agg_name2key
from .genome.genome import initialize_genomes
from .genome.mutate import mutate
from .genome.distance import distance
from .genome.crossover import crossover


class FunctionFactory:
    """Build and cache ahead-of-time compiled NEAT operators.

    Compiled executables are stored in ``self.compiled_function`` keyed by
    ``(name, n)``, where ``name`` is one of ``'mutate'``, ``'o2o_distance'``,
    ``'o2m_distance'`` or ``'crossover'`` and ``n`` is the node capacity the
    function was lowered for. Each executable is produced by
    ``jit(...).lower(...).compile()`` and is therefore specialized to the
    exact argument shapes/dtypes used when lowering.
    """

    def __init__(self, config, debug=False):
        """
        :param config: configuration object, read through the
            ``config.basic`` and ``config.neat`` attribute paths.
        :param debug: when True, ``create_initialize`` returns a plain
            forwarding wrapper instead of the ``functools.partial`` object.
        """
        self.config = config
        self.debug = debug

        self.init_N = config.basic.init_maximum_nodes
        # Growth factor applied to the node capacity between precompile steps.
        self.expand_coe = config.basic.expands_coe
        # Number of capacity sizes compiled eagerly by precompile().
        self.precompile_times = config.basic.pre_compile_times
        # Cache of compiled executables: (name, n) -> compiled function.
        self.compiled_function = {}

        self.load_config_vals(config)
        self.precompile()

    def load_config_vals(self, config):
        """Copy the hyper-parameters used by the genome operators onto ``self``."""
        self.pop_size = config.neat.population.pop_size
        self.init_N = config.basic.init_maximum_nodes

        self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
        self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient

        self.num_inputs = config.basic.num_inputs
        self.num_outputs = config.basic.num_outputs
        # Inputs occupy indices [0, num_inputs); outputs follow immediately.
        self.input_idx = np.arange(self.num_inputs)
        self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)

        bias = config.neat.gene.bias
        self.bias_mean = bias.init_mean
        self.bias_std = bias.init_stdev
        self.bias_mutate_strength = bias.mutate_power
        self.bias_mutate_rate = bias.mutate_rate
        self.bias_replace_rate = bias.replace_rate

        response = config.neat.gene.response
        self.response_mean = response.init_mean
        self.response_std = response.init_stdev
        self.response_mutate_strength = response.mutate_power
        self.response_mutate_rate = response.mutate_rate
        self.response_replace_rate = response.replace_rate

        weight = config.neat.gene.weight
        self.weight_mean = weight.init_mean
        self.weight_std = weight.init_stdev
        self.weight_mutate_strength = weight.mutate_power
        self.weight_mutate_rate = weight.mutate_rate
        self.weight_replace_rate = weight.replace_rate

        # Activation / aggregation names are translated to integer keys.
        activation = config.neat.gene.activation
        self.act_default = act_name2key[activation.default]
        self.act_list = np.array([act_name2key[name] for name in activation.options])
        self.act_replace_rate = activation.mutate_rate

        aggregation = config.neat.gene.aggregation
        self.agg_default = agg_name2key[aggregation.default]
        self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
        self.agg_replace_rate = aggregation.mutate_rate

        enabled = config.neat.gene.enabled
        self.enabled_reverse_rate = enabled.mutate_rate

        genome = config.neat.genome
        self.add_node_rate = genome.node_add_prob
        self.delete_node_rate = genome.node_delete_prob
        self.add_connection_rate = genome.conn_add_prob
        self.delete_connection_rate = genome.conn_delete_prob
        self.single_structure_mutate = genome.single_structural_mutation

    def create_initialize(self):
        """Return ``initialize_genomes`` with all config-derived arguments bound.

        Initial values use the configured means (bias, response, weight) and
        the default activation/aggregation keys.
        """
        func = partial(
            initialize_genomes,
            pop_size=self.pop_size,
            N=self.init_N,
            num_inputs=self.num_inputs,
            num_outputs=self.num_outputs,
            default_bias=self.bias_mean,
            default_response=self.response_mean,
            default_act=self.act_default,
            default_agg=self.agg_default,
            default_weight=self.weight_mean
        )
        if self.debug:
            # Forward keyword arguments as well, so the debug wrapper accepts
            # exactly the same calls as the partial returned otherwise.
            return lambda *args, **kwargs: func(*args, **kwargs)
        return func

    def precompile(self):
        """Bind operator arguments and eagerly compile them for several sizes.

        Starts at ``init_N`` and multiplies the capacity by ``expand_coe``
        (truncated to int) after each of ``precompile_times`` steps.
        """
        self.create_mutate_with_args()
        self.create_distance_with_args()
        self.create_crossover_with_args()

        n = self.init_N
        print("start precompile")
        for _ in range(self.precompile_times):
            # The create_* accessors compile only on a cache miss, so a size
            # that repeats (when int(expand_coe * n) == n) is not recompiled.
            self.create_mutate(n)
            self.create_distance(n)
            self.create_crossover(n)
            n = int(self.expand_coe * n)
        print("end precompile")

    def create_mutate_with_args(self):
        """Store ``mutate`` with every config-derived keyword argument bound."""
        func = partial(
            mutate,
            input_idx=self.input_idx,
            output_idx=self.output_idx,
            bias_mean=self.bias_mean,
            bias_std=self.bias_std,
            bias_mutate_strength=self.bias_mutate_strength,
            bias_mutate_rate=self.bias_mutate_rate,
            bias_replace_rate=self.bias_replace_rate,
            response_mean=self.response_mean,
            response_std=self.response_std,
            response_mutate_strength=self.response_mutate_strength,
            response_mutate_rate=self.response_mutate_rate,
            response_replace_rate=self.response_replace_rate,
            weight_mean=self.weight_mean,
            weight_std=self.weight_std,
            weight_mutate_strength=self.weight_mutate_strength,
            weight_mutate_rate=self.weight_mutate_rate,
            weight_replace_rate=self.weight_replace_rate,
            act_default=self.act_default,
            act_list=self.act_list,
            act_replace_rate=self.act_replace_rate,
            agg_default=self.agg_default,
            agg_list=self.agg_list,
            agg_replace_rate=self.agg_replace_rate,
            enabled_reverse_rate=self.enabled_reverse_rate,
            add_node_rate=self.add_node_rate,
            delete_node_rate=self.delete_node_rate,
            add_connection_rate=self.add_connection_rate,
            delete_connection_rate=self.delete_connection_rate,
            single_structure_mutate=self.single_structure_mutate
        )
        self.mutate_with_args = func

    def compile_mutate(self, n):
        """AOT-compile the population-batched mutate for node capacity ``n``.

        Lowered with: rand keys ``(pop_size, 2)`` uint32, nodes
        ``(pop_size, n, 5)``, connections ``(pop_size, 2, n, n)`` and new node
        keys ``(pop_size,)`` int32.
        """
        func = self.mutate_with_args
        rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
        # NOTE(review): float64 placeholders; with JAX's default (x64 disabled)
        # these are canonicalized to float32 — confirm that matches callers.
        nodes_lower = np.zeros((self.pop_size, n, 5))
        connections_lower = np.zeros((self.pop_size, 2, n, n))
        new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
        batched_mutate_func = jit(vmap(func)).lower(
            rand_key_lower, nodes_lower, connections_lower, new_node_key_lower
        ).compile()
        self.compiled_function[('mutate', n)] = batched_mutate_func

    def create_mutate(self, n):
        """Return the compiled batched mutate for capacity ``n``, compiling on a miss."""
        key = ('mutate', n)
        if key not in self.compiled_function:
            self.compile_mutate(n)
        return self.compiled_function[key]

    def create_distance_with_args(self):
        """Store ``distance`` with its compatibility coefficients bound."""
        func = partial(
            distance,
            disjoint_coe=self.disjoint_coe,
            compatibility_coe=self.compatibility_coe
        )
        self.distance_with_args = func

    def compile_distance(self, n):
        """AOT-compile one-to-one and one-to-many distance for capacity ``n``.

        ``o2o`` compares two single genomes; ``o2m`` compares one genome
        (not batched, ``in_axes=None``) against a population batch of
        ``pop_size`` genomes (batched along axis 0).
        """
        func = self.distance_with_args

        o2o_nodes1_lower = np.zeros((n, 5))
        o2o_connections1_lower = np.zeros((2, n, n))
        o2o_nodes2_lower = np.zeros((n, 5))
        o2o_connections2_lower = np.zeros((2, n, n))
        o2o_distance = jit(func).lower(
            o2o_nodes1_lower, o2o_connections1_lower,
            o2o_nodes2_lower, o2o_connections2_lower
        ).compile()

        o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
        o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n))
        o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(
            o2o_nodes1_lower, o2o_connections1_lower,
            o2m_nodes2_lower, o2m_connections2_lower
        ).compile()

        self.compiled_function[('o2o_distance', n)] = o2o_distance
        self.compiled_function[('o2m_distance', n)] = o2m_distance

    def create_distance(self, n):
        """Return ``(o2o_distance, o2m_distance)`` for capacity ``n``, compiling on a miss."""
        key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
        # Both executables are returned, so both must be present.
        if key1 not in self.compiled_function or key2 not in self.compiled_function:
            self.compile_distance(n)
        return self.compiled_function[key1], self.compiled_function[key2]

    def create_crossover_with_args(self):
        """Store the crossover operator (it needs no config-bound arguments)."""
        self.crossover_with_args = crossover

    def compile_crossover(self, n):
        """AOT-compile the population-batched crossover for node capacity ``n``.

        Every argument is batched along axis 0 with size ``pop_size``.
        """
        func = self.crossover_with_args
        randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
        nodes1_lower = np.zeros((self.pop_size, n, 5))
        connections1_lower = np.zeros((self.pop_size, 2, n, n))
        nodes2_lower = np.zeros((self.pop_size, n, 5))
        connections2_lower = np.zeros((self.pop_size, 2, n, n))
        compiled = jit(vmap(func)).lower(
            randkey_lower, nodes1_lower, connections1_lower,
            nodes2_lower, connections2_lower
        ).compile()
        self.compiled_function[('crossover', n)] = compiled

    def create_crossover(self, n):
        """Return the compiled batched crossover for capacity ``n``, compiling on a miss."""
        key = ('crossover', n)
        if key not in self.compiled_function:
            self.compile_crossover(n)
        return self.compiled_function[key]