create function_factory.py, use to manage functions
This commit is contained in:
215
algorithms/neat/function_factory.py
Normal file
215
algorithms/neat/function_factory.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Lowers, compiles, and creates functions used in the NEAT pipeline.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome import act_name2key, agg_name2key
|
||||
from .genome.genome import initialize_genomes
|
||||
from .genome.mutate import mutate
|
||||
from .genome.distance import distance
|
||||
from .genome.crossover import crossover
|
||||
|
||||
|
||||
class FunctionFactory:
|
||||
def __init__(self, config, debug=False):
|
||||
self.config = config
|
||||
self.debug = debug
|
||||
|
||||
self.init_N = config.basic.init_maximum_nodes
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.precompile_times = config.basic.pre_compile_times
|
||||
self.compiled_function = {}
|
||||
|
||||
self.load_config_vals(config)
|
||||
self.precompile()
|
||||
pass
|
||||
|
||||
def load_config_vals(self, config):
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
self.init_N = config.basic.init_maximum_nodes
|
||||
|
||||
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
|
||||
self.num_inputs = config.basic.num_inputs
|
||||
self.num_outputs = config.basic.num_outputs
|
||||
self.input_idx = np.arange(self.num_inputs)
|
||||
self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)
|
||||
|
||||
bias = config.neat.gene.bias
|
||||
self.bias_mean = bias.init_mean
|
||||
self.bias_std = bias.init_stdev
|
||||
self.bias_mutate_strength = bias.mutate_power
|
||||
self.bias_mutate_rate = bias.mutate_rate
|
||||
self.bias_replace_rate = bias.replace_rate
|
||||
|
||||
response = config.neat.gene.response
|
||||
self.response_mean = response.init_mean
|
||||
self.response_std = response.init_stdev
|
||||
self.response_mutate_strength = response.mutate_power
|
||||
self.response_mutate_rate = response.mutate_rate
|
||||
self.response_replace_rate = response.replace_rate
|
||||
|
||||
weight = config.neat.gene.weight
|
||||
self.weight_mean = weight.init_mean
|
||||
self.weight_std = weight.init_stdev
|
||||
self.weight_mutate_strength = weight.mutate_power
|
||||
self.weight_mutate_rate = weight.mutate_rate
|
||||
self.weight_replace_rate = weight.replace_rate
|
||||
|
||||
activation = config.neat.gene.activation
|
||||
self.act_default = act_name2key[activation.default]
|
||||
self.act_list = np.array([act_name2key[name] for name in activation.options])
|
||||
self.act_replace_rate = activation.mutate_rate
|
||||
|
||||
aggregation = config.neat.gene.aggregation
|
||||
self.agg_default = agg_name2key[aggregation.default]
|
||||
self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
|
||||
self.agg_replace_rate = aggregation.mutate_rate
|
||||
|
||||
enabled = config.neat.gene.enabled
|
||||
self.enabled_reverse_rate = enabled.mutate_rate
|
||||
|
||||
genome = config.neat.genome
|
||||
self.add_node_rate = genome.node_add_prob
|
||||
self.delete_node_rate = genome.node_delete_prob
|
||||
self.add_connection_rate = genome.conn_add_prob
|
||||
self.delete_connection_rate = genome.conn_delete_prob
|
||||
self.single_structure_mutate = genome.single_structural_mutation
|
||||
|
||||
def create_initialize(self):
|
||||
func = partial(
|
||||
initialize_genomes,
|
||||
pop_size=self.pop_size,
|
||||
N=self.init_N,
|
||||
num_inputs=self.num_inputs,
|
||||
num_outputs=self.num_outputs,
|
||||
default_bias=self.bias_mean,
|
||||
default_response=self.response_mean,
|
||||
default_act=self.act_default,
|
||||
default_agg=self.agg_default,
|
||||
default_weight=self.weight_mean
|
||||
)
|
||||
if self.debug:
|
||||
return lambda *args: func(*args)
|
||||
else:
|
||||
return func
|
||||
|
||||
def precompile(self):
|
||||
self.create_mutate_with_args()
|
||||
self.create_distance_with_args()
|
||||
self.create_crossover_with_args()
|
||||
n = self.init_N
|
||||
print("start precompile")
|
||||
for _ in range(self.precompile_times):
|
||||
self.compile_mutate(n)
|
||||
self.compile_distance(n)
|
||||
self.compile_crossover(n)
|
||||
n = int(self.expand_coe * n)
|
||||
print("end precompile")
|
||||
|
||||
def create_mutate_with_args(self):
|
||||
func = partial(
|
||||
mutate,
|
||||
input_idx=self.input_idx,
|
||||
output_idx=self.output_idx,
|
||||
bias_mean=self.bias_mean,
|
||||
bias_std=self.bias_std,
|
||||
bias_mutate_strength=self.bias_mutate_strength,
|
||||
bias_mutate_rate=self.bias_mutate_rate,
|
||||
bias_replace_rate=self.bias_replace_rate,
|
||||
response_mean=self.response_mean,
|
||||
response_std=self.response_std,
|
||||
response_mutate_strength=self.response_mutate_strength,
|
||||
response_mutate_rate=self.response_mutate_rate,
|
||||
response_replace_rate=self.response_replace_rate,
|
||||
weight_mean=self.weight_mean,
|
||||
weight_std=self.weight_std,
|
||||
weight_mutate_strength=self.weight_mutate_strength,
|
||||
weight_mutate_rate=self.weight_mutate_rate,
|
||||
weight_replace_rate=self.weight_replace_rate,
|
||||
act_default=self.act_default,
|
||||
act_list=self.act_list,
|
||||
act_replace_rate=self.act_replace_rate,
|
||||
agg_default=self.agg_default,
|
||||
agg_list=self.agg_list,
|
||||
agg_replace_rate=self.agg_replace_rate,
|
||||
enabled_reverse_rate=self.enabled_reverse_rate,
|
||||
add_node_rate=self.add_node_rate,
|
||||
delete_node_rate=self.delete_node_rate,
|
||||
add_connection_rate=self.add_connection_rate,
|
||||
delete_connection_rate=self.delete_connection_rate,
|
||||
single_structure_mutate=self.single_structure_mutate
|
||||
)
|
||||
self.mutate_with_args = func
|
||||
|
||||
def compile_mutate(self, n):
|
||||
func = self.mutate_with_args
|
||||
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
|
||||
connections_lower, new_node_key_lower).compile()
|
||||
self.compiled_function[('mutate', n)] = batched_mutate_func
|
||||
|
||||
def create_mutate(self, n):
|
||||
key = ('mutate', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_mutate(n)
|
||||
return self.compiled_function[key]
|
||||
|
||||
def create_distance_with_args(self):
|
||||
func = partial(
|
||||
distance,
|
||||
disjoint_coe=self.disjoint_coe,
|
||||
compatibility_coe=self.compatibility_coe
|
||||
)
|
||||
self.distance_with_args = func
|
||||
|
||||
def compile_distance(self, n):
|
||||
func = self.distance_with_args
|
||||
o2o_nodes1_lower = np.zeros((n, 5))
|
||||
o2o_connections1_lower = np.zeros((2, n, n))
|
||||
o2o_nodes2_lower = np.zeros((n, 5))
|
||||
o2o_connections2_lower = np.zeros((2, n, n))
|
||||
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||
o2o_nodes2_lower, o2o_connections2_lower).compile()
|
||||
|
||||
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||
o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||
o2m_nodes2_lower,
|
||||
o2m_connections2_lower).compile()
|
||||
|
||||
self.compiled_function[('o2o_distance', n)] = o2o_distance
|
||||
self.compiled_function[('o2m_distance', n)] = o2m_distance
|
||||
|
||||
def create_distance(self, n):
|
||||
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
|
||||
if key1 not in self.compiled_function:
|
||||
self.compile_distance(n)
|
||||
return self.compiled_function[key1], self.compiled_function[key2]
|
||||
|
||||
def create_crossover_with_args(self):
|
||||
self.crossover_with_args = crossover
|
||||
|
||||
def compile_crossover(self, n):
|
||||
func = self.crossover_with_args
|
||||
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||
nodes1_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections1_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections2_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
self.compiled_function[('crossover', n)] = func
|
||||
|
||||
def create_crossover(self, n):
|
||||
key = ('crossover', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_crossover(n)
|
||||
return self.compiled_function[key]
|
||||
@@ -3,3 +3,5 @@ from .distance import create_distance_function
|
||||
from .mutate import create_mutate_function
|
||||
from .forward import create_forward_function
|
||||
from .crossover import create_crossover_function
|
||||
from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
|
||||
@@ -33,9 +33,6 @@ def create_distance_function(N, config, type: str, debug: bool = False):
|
||||
else:
|
||||
return res_func
|
||||
|
||||
# return lambda nodes1, connections1, nodes2, connections2: \
|
||||
# distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
elif type == 'o2m':
|
||||
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
pop_size = config.neat.population.pop_size
|
||||
|
||||
@@ -45,7 +45,8 @@ def create_initialize_function(config):
|
||||
|
||||
def initialize_genomes(pop_size: int,
|
||||
N: int,
|
||||
num_inputs: int, num_outputs: int,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
default_bias: float = 0.0,
|
||||
default_response: float = 1.0,
|
||||
default_act: int = 0,
|
||||
|
||||
@@ -113,13 +113,11 @@ def mutate(rand_key: Array,
|
||||
new_node_key: int,
|
||||
input_idx: Array,
|
||||
output_idx: Array,
|
||||
bias_default: float = 0,
|
||||
bias_mean: float = 0,
|
||||
bias_std: float = 1,
|
||||
bias_mutate_strength: float = 0.5,
|
||||
bias_mutate_rate: float = 0.7,
|
||||
bias_replace_rate: float = 0.1,
|
||||
response_default: float = 1,
|
||||
response_mean: float = 1.,
|
||||
response_std: float = 0.,
|
||||
response_mutate_strength: float = 0.,
|
||||
@@ -147,8 +145,6 @@ def mutate(rand_key: Array,
|
||||
:param input_idx:
|
||||
:param agg_default:
|
||||
:param act_default:
|
||||
:param response_default:
|
||||
:param bias_default:
|
||||
:param rand_key:
|
||||
:param nodes: (N, 5)
|
||||
:param connections: (2, N, N)
|
||||
@@ -186,7 +182,7 @@ def mutate(rand_key: Array,
|
||||
return n, c
|
||||
|
||||
def m_add_node(rk, n, c):
|
||||
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
|
||||
return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, agg_default)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
@@ -2,13 +2,13 @@ from typing import List, Union, Tuple, Callable
|
||||
import time
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome import expand, expand_single
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
|
||||
create_distance_function, create_crossover_function
|
||||
from .function_factory import FunctionFactory
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -17,17 +17,18 @@ class Pipeline:
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.generation_timestamp = time.time()
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.config = config
|
||||
self.function_factory = FunctionFactory(config)
|
||||
|
||||
self.N = config.basic.init_maximum_nodes
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
|
||||
self.species_controller = SpeciesController(config)
|
||||
self.initialize_func = create_initialize_function(config)
|
||||
self.initialize_func = self.function_factory.create_initialize()
|
||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||
|
||||
self.compile_functions(debug=True)
|
||||
@@ -36,6 +37,7 @@ class Pipeline:
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
def ask(self, batch: bool):
|
||||
"""
|
||||
@@ -140,10 +142,9 @@ class Pipeline:
|
||||
self.compile_functions(debug=True)
|
||||
|
||||
def compile_functions(self, debug=False):
|
||||
self.mutate_func = create_mutate_function(self.N, self.config, batch=True, debug=debug)
|
||||
self.crossover_func = create_crossover_function(self.N, self.config, batch=True, debug=debug)
|
||||
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o', debug=debug)
|
||||
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m', debug=debug)
|
||||
self.mutate_func = self.function_factory.create_mutate(self.N)
|
||||
self.crossover_func = self.function_factory.create_crossover(self.N)
|
||||
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N)
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
"basic": {
|
||||
"num_inputs": 2,
|
||||
"num_outputs": 1,
|
||||
"init_maximum_nodes": 30,
|
||||
"expands_coe": 2
|
||||
"init_maximum_nodes": 10,
|
||||
"expands_coe": 2,
|
||||
"pre_compile_times": 3
|
||||
},
|
||||
"neat": {
|
||||
"population": {
|
||||
|
||||
Reference in New Issue
Block a user