finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
This commit is contained in:
@@ -1,323 +1,108 @@
|
||||
"""
|
||||
Lowers, compiles, and creates functions used in the NEAT pipeline.
|
||||
"""
|
||||
from functools import partial
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome import act_name2key, agg_name2key, initialize_genomes
|
||||
from .genome import topological_sort, forward_single, unflatten_connections
|
||||
from .population import create_next_generation_then_speciate
|
||||
from .genome.forward import create_forward
|
||||
from .genome.utils import unflatten_connections
|
||||
from .genome.graph import topological_sort
|
||||
|
||||
|
||||
def hash_symbols(symbols):
|
||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||
|
||||
|
||||
class FunctionFactory:
|
||||
"""
|
||||
Creates and compiles functions used in the NEAT pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.func_dict = {}
|
||||
self.function_info = {}
|
||||
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.precompile_times = config.basic.pre_compile_times
|
||||
self.compiled_function = {}
|
||||
self.compile_time = 0
|
||||
# (inputs_nums, ) -> (outputs_nums, )
|
||||
forward = create_forward(config) # input size (inputs_nums, )
|
||||
|
||||
self.load_config_vals(config)
|
||||
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
|
||||
batch_forward = vmap(forward, in_axes=(0, None, None, None))
|
||||
|
||||
self.create_topological_sort_with_args()
|
||||
self.create_single_forward_with_args()
|
||||
self.create_update_speciate_with_args()
|
||||
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
|
||||
|
||||
def load_config_vals(self, config):
|
||||
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
|
||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
self.problem_batch = config.basic.problem_batch
|
||||
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
self.function_info = {
|
||||
"pop_unflatten_connections": {
|
||||
'func': vmap(unflatten_connections),
|
||||
'lowers': [
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 'C', 4), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
"pop_topological_sort": {
|
||||
'func': vmap(topological_sort),
|
||||
'lowers': [
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32},
|
||||
]
|
||||
},
|
||||
|
||||
self.num_inputs = config.basic.num_inputs
|
||||
self.num_outputs = config.basic.num_outputs
|
||||
self.input_idx = np.arange(self.num_inputs)
|
||||
self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)
|
||||
"batch_forward": {
|
||||
'func': batch_forward,
|
||||
'lowers': [
|
||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('N', ), 'type': np.int32},
|
||||
{'shape': ('N', 5), 'type': np.float32},
|
||||
{'shape': (2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
bias = config.neat.gene.bias
|
||||
self.bias_mean = bias.init_mean
|
||||
self.bias_std = bias.init_stdev
|
||||
self.bias_mutate_strength = bias.mutate_power
|
||||
self.bias_mutate_rate = bias.mutate_rate
|
||||
self.bias_replace_rate = bias.replace_rate
|
||||
"pop_batch_forward": {
|
||||
'func': pop_batch_forward,
|
||||
'lowers': [
|
||||
{'shape': ('P', config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('P', 'N'), 'type': np.int32},
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
response = config.neat.gene.response
|
||||
self.response_mean = response.init_mean
|
||||
self.response_std = response.init_stdev
|
||||
self.response_mutate_strength = response.mutate_power
|
||||
self.response_mutate_rate = response.mutate_rate
|
||||
self.response_replace_rate = response.replace_rate
|
||||
|
||||
weight = config.neat.gene.weight
|
||||
self.weight_mean = weight.init_mean
|
||||
self.weight_std = weight.init_stdev
|
||||
self.weight_mutate_strength = weight.mutate_power
|
||||
self.weight_mutate_rate = weight.mutate_rate
|
||||
self.weight_replace_rate = weight.replace_rate
|
||||
|
||||
activation = config.neat.gene.activation
|
||||
self.act_default = act_name2key[activation.default]
|
||||
self.act_list = np.array([act_name2key[name] for name in activation.options])
|
||||
self.act_replace_rate = activation.mutate_rate
|
||||
|
||||
aggregation = config.neat.gene.aggregation
|
||||
self.agg_default = agg_name2key[aggregation.default]
|
||||
self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
|
||||
self.agg_replace_rate = aggregation.mutate_rate
|
||||
|
||||
enabled = config.neat.gene.enabled
|
||||
self.enabled_reverse_rate = enabled.mutate_rate
|
||||
|
||||
genome = config.neat.genome
|
||||
self.add_node_rate = genome.node_add_prob
|
||||
self.delete_node_rate = genome.node_delete_prob
|
||||
self.add_connection_rate = genome.conn_add_prob
|
||||
self.delete_connection_rate = genome.conn_delete_prob
|
||||
self.single_structure_mutate = genome.single_structural_mutation
|
||||
|
||||
def create_initialize(self, N, C):
|
||||
func = partial(
|
||||
initialize_genomes,
|
||||
pop_size=self.pop_size,
|
||||
N=N,
|
||||
C=C,
|
||||
num_inputs=self.num_inputs,
|
||||
num_outputs=self.num_outputs,
|
||||
default_bias=self.bias_mean,
|
||||
default_response=self.response_mean,
|
||||
default_act=self.act_default,
|
||||
default_agg=self.agg_default,
|
||||
default_weight=self.weight_mean
|
||||
)
|
||||
return func
|
||||
|
||||
def create_update_speciate_with_args(self):
|
||||
species_kwargs = {
|
||||
"disjoint_coe": self.disjoint_coe,
|
||||
"compatibility_coe": self.compatibility_coe,
|
||||
"compatibility_threshold": self.compatibility_threshold
|
||||
'common_forward': {
|
||||
'func': common_forward,
|
||||
'lowers': [
|
||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('P', 'N'), 'type': np.int32},
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mutate_kwargs = {
|
||||
"input_idx": self.input_idx,
|
||||
"output_idx": self.output_idx,
|
||||
"bias_mean": self.bias_mean,
|
||||
"bias_std": self.bias_std,
|
||||
"bias_mutate_strength": self.bias_mutate_strength,
|
||||
"bias_mutate_rate": self.bias_mutate_rate,
|
||||
"bias_replace_rate": self.bias_replace_rate,
|
||||
"response_mean": self.response_mean,
|
||||
"response_std": self.response_std,
|
||||
"response_mutate_strength": self.response_mutate_strength,
|
||||
"response_mutate_rate": self.response_mutate_rate,
|
||||
"response_replace_rate": self.response_replace_rate,
|
||||
"weight_mean": self.weight_mean,
|
||||
"weight_std": self.weight_std,
|
||||
"weight_mutate_strength": self.weight_mutate_strength,
|
||||
"weight_mutate_rate": self.weight_mutate_rate,
|
||||
"weight_replace_rate": self.weight_replace_rate,
|
||||
"act_default": self.act_default,
|
||||
"act_list": self.act_list,
|
||||
"act_replace_rate": self.act_replace_rate,
|
||||
"agg_default": self.agg_default,
|
||||
"agg_list": self.agg_list,
|
||||
"agg_replace_rate": self.agg_replace_rate,
|
||||
"enabled_reverse_rate": self.enabled_reverse_rate,
|
||||
"add_node_rate": self.add_node_rate,
|
||||
"delete_node_rate": self.delete_node_rate,
|
||||
"add_connection_rate": self.add_connection_rate,
|
||||
"delete_connection_rate": self.delete_connection_rate,
|
||||
}
|
||||
|
||||
self.update_speciate_with_args = partial(
|
||||
create_next_generation_then_speciate,
|
||||
species_kwargs=species_kwargs,
|
||||
mutate_kwargs=mutate_kwargs
|
||||
)
|
||||
def get(self, name, symbols):
|
||||
if (name, hash_symbols(symbols)) not in self.func_dict:
|
||||
self.compile(name, symbols)
|
||||
return self.func_dict[name, hash_symbols(symbols)]
|
||||
|
||||
def create_update_speciate(self, N, C, S):
|
||||
key = ("update_speciate", N, C, S)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_update_speciate(N, C, S)
|
||||
return self.compiled_function[key]
|
||||
def compile(self, name, symbols):
|
||||
# prepare function prototype
|
||||
func = self.function_info[name]['func']
|
||||
|
||||
def compile_update_speciate(self, N, C, S):
|
||||
s = time.time()
|
||||
# prepare lower operands
|
||||
lowers_operands = []
|
||||
for lower in self.function_info[name]['lowers']:
|
||||
shape = list(lower['shape'])
|
||||
for i, s in enumerate(shape):
|
||||
if s in symbols:
|
||||
shape[i] = symbols[s]
|
||||
assert isinstance(shape[i], int)
|
||||
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
||||
|
||||
func = self.update_speciate_with_args
|
||||
randkey_lower = np.zeros((2,), dtype=np.uint32)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
|
||||
pop_cons_lower = np.zeros((self.pop_size, C, 4))
|
||||
winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
elite_mask_lower = np.zeros((self.pop_size,), dtype=bool)
|
||||
new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
pre_spe_center_nodes_lower = np.zeros((S, N, 5))
|
||||
pre_spe_center_cons_lower = np.zeros((S, C, 4))
|
||||
species_keys = np.zeros((S,), dtype=np.int32)
|
||||
new_species_keys_lower = 0
|
||||
compiled_func = jit(func).lower(
|
||||
randkey_lower,
|
||||
pop_nodes_lower,
|
||||
pop_cons_lower,
|
||||
winner_part_lower,
|
||||
loser_part_lower,
|
||||
elite_mask_lower,
|
||||
new_node_keys_start_lower,
|
||||
pre_spe_center_nodes_lower,
|
||||
pre_spe_center_cons_lower,
|
||||
species_keys,
|
||||
new_species_keys_lower,
|
||||
).compile()
|
||||
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
|
||||
# compile
|
||||
compiled_func = jit(func).lower(*lowers_operands).compile()
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort_with_args(self):
|
||||
self.topological_sort_with_args = topological_sort
|
||||
|
||||
def compile_topological_sort(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.topological_sort_with_args
|
||||
nodes_lower = np.zeros((n, 5))
|
||||
connections_lower = np.zeros((2, n, n))
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort(self, n):
|
||||
key = ('topological_sort', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_topological_sort(n)
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_topological_sort_batch(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.topological_sort_with_args
|
||||
func = vmap(func)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort_batch', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort_batch(self, n):
|
||||
key = ('topological_sort_batch', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_topological_sort_batch(n)
|
||||
return self.compiled_function[key]
|
||||
|
||||
def create_single_forward_with_args(self):
|
||||
func = partial(
|
||||
forward_single,
|
||||
input_idx=self.input_idx,
|
||||
output_idx=self.output_idx
|
||||
)
|
||||
self.single_forward_with_args = func
|
||||
|
||||
|
||||
def compile_batch_forward(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(0, None, None, None))
|
||||
|
||||
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
|
||||
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
|
||||
nodes_lower = np.zeros((n, 5))
|
||||
connections_lower = np.zeros((2, n, n))
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('batch_forward', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_batch_forward(self, n):
|
||||
key = ('batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_forward(n)
|
||||
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_pop_batch_forward(self, n):
|
||||
|
||||
s = time.time()
|
||||
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
||||
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
|
||||
|
||||
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
|
||||
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('pop_batch_forward', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_pop_batch_forward(self, n):
|
||||
key = ('pop_batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_pop_batch_forward(n)
|
||||
|
||||
return self.compiled_function[key]
|
||||
|
||||
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
|
||||
n, c = pop_nodes.shape[1], pop_cons.shape[1]
|
||||
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
|
||||
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
|
||||
ts = self.create_topological_sort_batch(n)
|
||||
|
||||
# for connections with enabled is false, set weight to 0)
|
||||
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
||||
# print(pop_cal_seqs)
|
||||
forward_func = self.create_pop_batch_forward(n)
|
||||
|
||||
def debug_forward(inputs):
|
||||
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons)
|
||||
|
||||
return debug_forward
|
||||
|
||||
def ask_batch_forward(self, nodes, connections):
|
||||
n = nodes.shape[0]
|
||||
ts = self.create_topological_sort(n)
|
||||
cal_seqs = ts(nodes, connections)
|
||||
forward_func = self.create_batch_forward(n)
|
||||
|
||||
def debug_forward(inputs):
|
||||
return forward_func(inputs, cal_seqs, nodes, connections)
|
||||
|
||||
return debug_forward
|
||||
|
||||
def compile_batch_unflatten_connections(self, n, c):
|
||||
|
||||
s = time.time()
|
||||
|
||||
func = unflatten_connections
|
||||
func = vmap(func)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
pop_connections_lower = np.zeros((self.pop_size, c, 4))
|
||||
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
|
||||
self.compiled_function[('batch_unflatten_connections', n, c)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_batch_unflatten_connections(self, n, c):
|
||||
key = ('batch_unflatten_connections', n, c)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_unflatten_connections(n, c)
|
||||
|
||||
return self.compiled_function[key]
|
||||
# save for reuse
|
||||
self.func_dict[name, hash_symbols(symbols)] = compiled_func
|
||||
|
||||
Reference in New Issue
Block a user