324 lines
12 KiB
Python
324 lines
12 KiB
Python
"""
Lowers, compiles, and caches the JAX functions used in the NEAT pipeline.
"""
|
|
from functools import partial
|
|
import time
|
|
|
|
import numpy as np
|
|
from jax import jit, vmap
|
|
|
|
from .genome import act_name2key, agg_name2key, initialize_genomes
|
|
from .genome import topological_sort, forward_single, unflatten_connections
|
|
from .population import create_next_generation_then_speciate
|
|
|
|
|
|
class FunctionFactory:
    """Create, lower, compile and cache the JAX functions of the NEAT pipeline.

    Each ``compile_*`` method lowers a function with ``jit(func).lower(...)``
    on zero-filled example arrays of a fixed shape and compiles it ahead of
    time. The executable is stored in ``self.compiled_function`` under a tuple
    key ``(name, *shape_params)``. The matching ``create_*`` method returns the
    cached executable, compiling it on first request. Seconds spent lowering
    and compiling are accumulated in ``self.compile_time``.
    """

    def __init__(self, config):
        """Read hyper-parameters from *config* and build the partial functions.

        Args:
            config: nested config object; every ``config.basic.*`` /
                ``config.neat.*`` attribute read here and in
                ``load_config_vals`` must exist.
        """
        self.config = config

        self.expand_coe = config.basic.expands_coe
        self.precompile_times = config.basic.pre_compile_times
        # Ahead-of-time compiled executables, keyed by (name, *shape_params).
        self.compiled_function = {}
        # Total seconds spent lowering + compiling.
        self.compile_time = 0

        self.load_config_vals(config)

        # The *_with_args partials bake config values in as keyword arguments,
        # so they must be built after load_config_vals.
        self.create_topological_sort_with_args()
        self.create_single_forward_with_args()
        self.create_update_speciate_with_args()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _lower_and_compile(self, key, func, *example_args):
        """JIT-lower *func* on *example_args*, compile it and cache it under *key*.

        The elapsed time is measured with a monotonic clock
        (``time.perf_counter``) and added to ``self.compile_time``.

        Returns:
            The compiled executable (also stored in ``self.compiled_function``).
        """
        start = time.perf_counter()
        compiled = jit(func).lower(*example_args).compile()
        self.compiled_function[key] = compiled
        self.compile_time += time.perf_counter() - start
        return compiled

    def _get_or_compile(self, key, compile_fn, *shape_args):
        """Return the executable cached under *key*, running ``compile_fn(*shape_args)`` on a miss."""
        if key not in self.compiled_function:
            compile_fn(*shape_args)
        return self.compiled_function[key]

    # ------------------------------------------------------------------ #
    # Configuration
    # ------------------------------------------------------------------ #

    def load_config_vals(self, config):
        """Copy the hyper-parameters used by the pipeline from *config* onto ``self``.

        Activation/aggregation names are translated to integer keys through
        ``act_name2key`` / ``agg_name2key``.
        """
        # Read from the *config* argument (not self.config) like every other value.
        self.compatibility_threshold = config.neat.species.compatibility_threshold

        self.problem_batch = config.basic.problem_batch

        self.pop_size = config.neat.population.pop_size

        self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
        self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient

        self.num_inputs = config.basic.num_inputs
        self.num_outputs = config.basic.num_outputs
        # Input nodes occupy indices [0, num_inputs); outputs follow directly.
        self.input_idx = np.arange(self.num_inputs)
        self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)

        bias = config.neat.gene.bias
        self.bias_mean = bias.init_mean
        self.bias_std = bias.init_stdev
        self.bias_mutate_strength = bias.mutate_power
        self.bias_mutate_rate = bias.mutate_rate
        self.bias_replace_rate = bias.replace_rate

        response = config.neat.gene.response
        self.response_mean = response.init_mean
        self.response_std = response.init_stdev
        self.response_mutate_strength = response.mutate_power
        self.response_mutate_rate = response.mutate_rate
        self.response_replace_rate = response.replace_rate

        weight = config.neat.gene.weight
        self.weight_mean = weight.init_mean
        self.weight_std = weight.init_stdev
        self.weight_mutate_strength = weight.mutate_power
        self.weight_mutate_rate = weight.mutate_rate
        self.weight_replace_rate = weight.replace_rate

        activation = config.neat.gene.activation
        self.act_default = act_name2key[activation.default]
        self.act_list = np.array([act_name2key[name] for name in activation.options])
        self.act_replace_rate = activation.mutate_rate

        aggregation = config.neat.gene.aggregation
        self.agg_default = agg_name2key[aggregation.default]
        self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
        self.agg_replace_rate = aggregation.mutate_rate

        enabled = config.neat.gene.enabled
        self.enabled_reverse_rate = enabled.mutate_rate

        genome = config.neat.genome
        self.add_node_rate = genome.node_add_prob
        self.delete_node_rate = genome.node_delete_prob
        self.add_connection_rate = genome.conn_add_prob
        self.delete_connection_rate = genome.conn_delete_prob
        self.single_structure_mutate = genome.single_structural_mutation

    # ------------------------------------------------------------------ #
    # Initialization
    # ------------------------------------------------------------------ #

    def create_initialize(self, N, C):
        """Return ``initialize_genomes`` with population size, sizes and defaults bound.

        Args:
            N: node capacity per genome.
            C: connection capacity per genome.
        """
        return partial(
            initialize_genomes,
            pop_size=self.pop_size,
            N=N,
            C=C,
            num_inputs=self.num_inputs,
            num_outputs=self.num_outputs,
            default_bias=self.bias_mean,
            default_response=self.response_mean,
            default_act=self.act_default,
            default_agg=self.agg_default,
            default_weight=self.weight_mean,
        )

    # ------------------------------------------------------------------ #
    # Reproduction + speciation
    # ------------------------------------------------------------------ #

    def create_update_speciate_with_args(self):
        """Bind speciation and mutation hyper-parameters to ``create_next_generation_then_speciate``."""
        species_kwargs = {
            "disjoint_coe": self.disjoint_coe,
            "compatibility_coe": self.compatibility_coe,
            "compatibility_threshold": self.compatibility_threshold,
        }

        mutate_kwargs = {
            "input_idx": self.input_idx,
            "output_idx": self.output_idx,
            "bias_mean": self.bias_mean,
            "bias_std": self.bias_std,
            "bias_mutate_strength": self.bias_mutate_strength,
            "bias_mutate_rate": self.bias_mutate_rate,
            "bias_replace_rate": self.bias_replace_rate,
            "response_mean": self.response_mean,
            "response_std": self.response_std,
            "response_mutate_strength": self.response_mutate_strength,
            "response_mutate_rate": self.response_mutate_rate,
            "response_replace_rate": self.response_replace_rate,
            "weight_mean": self.weight_mean,
            "weight_std": self.weight_std,
            "weight_mutate_strength": self.weight_mutate_strength,
            "weight_mutate_rate": self.weight_mutate_rate,
            "weight_replace_rate": self.weight_replace_rate,
            "act_default": self.act_default,
            "act_list": self.act_list,
            "act_replace_rate": self.act_replace_rate,
            "agg_default": self.agg_default,
            "agg_list": self.agg_list,
            "agg_replace_rate": self.agg_replace_rate,
            "enabled_reverse_rate": self.enabled_reverse_rate,
            "add_node_rate": self.add_node_rate,
            "delete_node_rate": self.delete_node_rate,
            "add_connection_rate": self.add_connection_rate,
            "delete_connection_rate": self.delete_connection_rate,
        }

        self.update_speciate_with_args = partial(
            create_next_generation_then_speciate,
            species_kwargs=species_kwargs,
            mutate_kwargs=mutate_kwargs,
        )

    def create_update_speciate(self, N, C, S):
        """Return the compiled update+speciate executable for sizes (N, C, S)."""
        return self._get_or_compile(("update_speciate", N, C, S), self.compile_update_speciate, N, C, S)

    def compile_update_speciate(self, N, C, S):
        """Ahead-of-time compile ``update_speciate_with_args`` for fixed sizes.

        Args:
            N: node rows per genome (node arrays are lowered as ``(.., N, 5)``).
            C: connection rows per genome (``(.., C, 4)``).
            S: number of species (rows of the species-centre arrays).
        """
        pop = self.pop_size
        example_args = (
            np.zeros((2,), dtype=np.uint32),   # randkey
            np.zeros((pop, N, 5)),             # pop_nodes
            np.zeros((pop, C, 4)),             # pop_cons
            np.zeros((pop,), dtype=np.int32),  # winner_part
            np.zeros((pop,), dtype=np.int32),  # loser_part
            np.zeros((pop,), dtype=bool),      # elite_mask
            np.zeros((pop,), dtype=np.int32),  # new_node_keys_start
            np.zeros((S, N, 5)),               # pre_spe_center_nodes
            np.zeros((S, C, 4)),               # pre_spe_center_cons
            np.zeros((S,), dtype=np.int32),    # species_keys
            0,                                 # new_species_keys (scalar)
        )
        self._lower_and_compile(
            ("update_speciate", N, C, S), self.update_speciate_with_args, *example_args
        )

    # ------------------------------------------------------------------ #
    # Topological sort
    # ------------------------------------------------------------------ #

    def create_topological_sort_with_args(self):
        """Store the (argument-free) topological sort function."""
        self.topological_sort_with_args = topological_sort

    def compile_topological_sort(self, n):
        """Compile the single-genome topological sort for nodes ``(n, 5)`` and connections ``(2, n, n)``."""
        self._lower_and_compile(
            ('topological_sort', n),
            self.topological_sort_with_args,
            np.zeros((n, 5)),
            np.zeros((2, n, n)),
        )

    def create_topological_sort(self, n):
        """Return the compiled single-genome topological sort for size *n*."""
        return self._get_or_compile(('topological_sort', n), self.compile_topological_sort, n)

    def compile_topological_sort_batch(self, n):
        """Compile the population-vmapped topological sort for size *n*.

        Lowered with nodes ``(pop_size, n, 5)`` and connections ``(pop_size, 2, n, n)``.
        """
        self._lower_and_compile(
            ('topological_sort_batch', n),
            vmap(self.topological_sort_with_args),
            np.zeros((self.pop_size, n, 5)),
            np.zeros((self.pop_size, 2, n, n)),
        )

    def create_topological_sort_batch(self, n):
        """Return the compiled population-vmapped topological sort for size *n*."""
        return self._get_or_compile(
            ('topological_sort_batch', n), self.compile_topological_sort_batch, n
        )

    # ------------------------------------------------------------------ #
    # Forward passes
    # ------------------------------------------------------------------ #

    def create_single_forward_with_args(self):
        """Bind the input/output node indices to ``forward_single``."""
        self.single_forward_with_args = partial(
            forward_single,
            input_idx=self.input_idx,
            output_idx=self.output_idx,
        )

    def compile_batch_forward(self, n):
        """Compile a forward pass of one genome over a batch of inputs.

        Only ``inputs`` is vmapped; it is lowered as ``(problem_batch, num_inputs)``.
        The genome arrays are shared: cal_seqs ``(n,)`` int32, nodes ``(n, 5)``,
        connections ``(2, n, n)``.
        """
        batch_forward = vmap(self.single_forward_with_args, in_axes=(0, None, None, None))
        self._lower_and_compile(
            ('batch_forward', n),
            batch_forward,
            np.zeros((self.problem_batch, self.num_inputs)),
            np.zeros((n,), dtype=np.int32),
            np.zeros((n, 5)),
            np.zeros((2, n, n)),
        )

    def create_batch_forward(self, n):
        """Return the compiled single-genome batch forward for size *n*."""
        return self._get_or_compile(('batch_forward', n), self.compile_batch_forward, n)

    def compile_pop_batch_forward(self, n):
        """Compile a forward pass of the whole population over a shared input batch.

        The inner vmap maps over inputs and the outer vmap over genomes. Inputs
        are lowered as ``(problem_batch, num_inputs)`` and shared by all genomes;
        the genome arrays carry a leading ``pop_size`` axis.
        """
        batch_forward = vmap(self.single_forward_with_args, in_axes=(0, None, None, None))
        pop_batch_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
        self._lower_and_compile(
            ('pop_batch_forward', n),
            pop_batch_forward,
            np.zeros((self.problem_batch, self.num_inputs)),
            np.zeros((self.pop_size, n), dtype=np.int32),
            np.zeros((self.pop_size, n, 5)),
            np.zeros((self.pop_size, 2, n, n)),
        )

    def create_pop_batch_forward(self, n):
        """Return the compiled population batch forward for size *n*."""
        return self._get_or_compile(('pop_batch_forward', n), self.compile_pop_batch_forward, n)

    def ask_pop_batch_forward(self, pop_nodes, pop_cons):
        """Build a forward function that evaluates every genome of the population.

        Args:
            pop_nodes: population node array, ``(pop_size, n, 5)``.
            pop_cons: flattened population connections, ``(pop_size, c, 4)``.

        Returns:
            ``forward(inputs)``, which runs the compiled population forward pass
            on ``inputs`` of shape ``(problem_batch, num_inputs)``.
        """
        n, c = pop_nodes.shape[1], pop_cons.shape[1]

        unflatten = self.create_batch_unflatten_connections(n, c)
        # (pop_size, c, 4) -> (pop_size, 2, n, n), matching the lowering shapes.
        # NOTE(review): an earlier comment suggested disabled connections end up
        # with weight 0; presumably handled inside unflatten_connections — confirm.
        pop_cons = unflatten(pop_nodes, pop_cons)

        sort_batch = self.create_topological_sort_batch(n)
        pop_cal_seqs = sort_batch(pop_nodes, pop_cons)

        pop_forward = self.create_pop_batch_forward(n)

        def forward(inputs):
            return pop_forward(inputs, pop_cal_seqs, pop_nodes, pop_cons)

        return forward

    def ask_batch_forward(self, nodes, connections):
        """Build a forward function for a single genome.

        Args:
            nodes: node array ``(n, 5)``.
            connections: unflattened connections ``(2, n, n)``.

        Returns:
            ``forward(inputs)``, which runs the compiled batch forward pass on
            ``inputs`` of shape ``(problem_batch, num_inputs)``.
        """
        n = nodes.shape[0]
        cal_seqs = self.create_topological_sort(n)(nodes, connections)
        batch_forward = self.create_batch_forward(n)

        def forward(inputs):
            return batch_forward(inputs, cal_seqs, nodes, connections)

        return forward

    # ------------------------------------------------------------------ #
    # Connection unflattening
    # ------------------------------------------------------------------ #

    def compile_batch_unflatten_connections(self, n, c):
        """Compile the population-vmapped ``unflatten_connections``.

        Lowered with nodes ``(pop_size, n, 5)`` and connections ``(pop_size, c, 4)``.
        """
        self._lower_and_compile(
            ('batch_unflatten_connections', n, c),
            vmap(unflatten_connections),
            np.zeros((self.pop_size, n, 5)),
            np.zeros((self.pop_size, c, 4)),
        )

    def create_batch_unflatten_connections(self, n, c):
        """Return the compiled population unflatten function for sizes (n, c)."""
        return self._get_or_compile(
            ('batch_unflatten_connections', n, c), self.compile_batch_unflatten_connections, n, c
        )