create function_factory.py, use to manage functions

This commit is contained in:
wls2002
2023-05-09 01:49:43 +08:00
parent ee6bb01eff
commit f63a0c447b
7 changed files with 231 additions and 18 deletions

View File

@@ -0,0 +1,215 @@
"""
Lowers, compiles, and creates functions used in the NEAT pipeline.
"""
from functools import partial
import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key
from .genome.genome import initialize_genomes
from .genome.mutate import mutate
from .genome.distance import distance
from .genome.crossover import crossover
class FunctionFactory:
def __init__(self, config, debug=False):
self.config = config
self.debug = debug
self.init_N = config.basic.init_maximum_nodes
self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {}
self.load_config_vals(config)
self.precompile()
pass
def load_config_vals(self, config):
self.pop_size = config.neat.population.pop_size
self.init_N = config.basic.init_maximum_nodes
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
self.num_inputs = config.basic.num_inputs
self.num_outputs = config.basic.num_outputs
self.input_idx = np.arange(self.num_inputs)
self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)
bias = config.neat.gene.bias
self.bias_mean = bias.init_mean
self.bias_std = bias.init_stdev
self.bias_mutate_strength = bias.mutate_power
self.bias_mutate_rate = bias.mutate_rate
self.bias_replace_rate = bias.replace_rate
response = config.neat.gene.response
self.response_mean = response.init_mean
self.response_std = response.init_stdev
self.response_mutate_strength = response.mutate_power
self.response_mutate_rate = response.mutate_rate
self.response_replace_rate = response.replace_rate
weight = config.neat.gene.weight
self.weight_mean = weight.init_mean
self.weight_std = weight.init_stdev
self.weight_mutate_strength = weight.mutate_power
self.weight_mutate_rate = weight.mutate_rate
self.weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
self.act_default = act_name2key[activation.default]
self.act_list = np.array([act_name2key[name] for name in activation.options])
self.act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
self.agg_default = agg_name2key[aggregation.default]
self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
self.agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
self.enabled_reverse_rate = enabled.mutate_rate
genome = config.neat.genome
self.add_node_rate = genome.node_add_prob
self.delete_node_rate = genome.node_delete_prob
self.add_connection_rate = genome.conn_add_prob
self.delete_connection_rate = genome.conn_delete_prob
self.single_structure_mutate = genome.single_structural_mutation
def create_initialize(self):
func = partial(
initialize_genomes,
pop_size=self.pop_size,
N=self.init_N,
num_inputs=self.num_inputs,
num_outputs=self.num_outputs,
default_bias=self.bias_mean,
default_response=self.response_mean,
default_act=self.act_default,
default_agg=self.agg_default,
default_weight=self.weight_mean
)
if self.debug:
return lambda *args: func(*args)
else:
return func
def precompile(self):
self.create_mutate_with_args()
self.create_distance_with_args()
self.create_crossover_with_args()
n = self.init_N
print("start precompile")
for _ in range(self.precompile_times):
self.compile_mutate(n)
self.compile_distance(n)
self.compile_crossover(n)
n = int(self.expand_coe * n)
print("end precompile")
def create_mutate_with_args(self):
func = partial(
mutate,
input_idx=self.input_idx,
output_idx=self.output_idx,
bias_mean=self.bias_mean,
bias_std=self.bias_std,
bias_mutate_strength=self.bias_mutate_strength,
bias_mutate_rate=self.bias_mutate_rate,
bias_replace_rate=self.bias_replace_rate,
response_mean=self.response_mean,
response_std=self.response_std,
response_mutate_strength=self.response_mutate_strength,
response_mutate_rate=self.response_mutate_rate,
response_replace_rate=self.response_replace_rate,
weight_mean=self.weight_mean,
weight_std=self.weight_std,
weight_mutate_strength=self.weight_mutate_strength,
weight_mutate_rate=self.weight_mutate_rate,
weight_replace_rate=self.weight_replace_rate,
act_default=self.act_default,
act_list=self.act_list,
act_replace_rate=self.act_replace_rate,
agg_default=self.agg_default,
agg_list=self.agg_list,
agg_replace_rate=self.agg_replace_rate,
enabled_reverse_rate=self.enabled_reverse_rate,
add_node_rate=self.add_node_rate,
delete_node_rate=self.delete_node_rate,
add_connection_rate=self.add_connection_rate,
delete_connection_rate=self.delete_connection_rate,
single_structure_mutate=self.single_structure_mutate
)
self.mutate_with_args = func
def compile_mutate(self, n):
func = self.mutate_with_args
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
self.compiled_function[('mutate', n)] = batched_mutate_func
def create_mutate(self, n):
key = ('mutate', n)
if key not in self.compiled_function:
self.compile_mutate(n)
return self.compiled_function[key]
def create_distance_with_args(self):
func = partial(
distance,
disjoint_coe=self.disjoint_coe,
compatibility_coe=self.compatibility_coe
)
self.distance_with_args = func
def compile_distance(self, n):
func = self.distance_with_args
o2o_nodes1_lower = np.zeros((n, 5))
o2o_connections1_lower = np.zeros((2, n, n))
o2o_nodes2_lower = np.zeros((n, 5))
o2o_connections2_lower = np.zeros((2, n, n))
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2o_nodes2_lower, o2o_connections2_lower).compile()
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n))
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2m_nodes2_lower,
o2m_connections2_lower).compile()
self.compiled_function[('o2o_distance', n)] = o2o_distance
self.compiled_function[('o2m_distance', n)] = o2m_distance
def create_distance(self, n):
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
if key1 not in self.compiled_function:
self.compile_distance(n)
return self.compiled_function[key1], self.compiled_function[key2]
def create_crossover_with_args(self):
self.crossover_with_args = crossover
def compile_crossover(self, n):
func = self.crossover_with_args
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes1_lower = np.zeros((self.pop_size, n, 5))
connections1_lower = np.zeros((self.pop_size, 2, n, n))
nodes2_lower = np.zeros((self.pop_size, n, 5))
connections2_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
self.compiled_function[('crossover', n)] = func
def create_crossover(self, n):
key = ('crossover', n)
if key not in self.compiled_function:
self.compile_crossover(n)
return self.compiled_function[key]