create function_factory.py, use to manage functions

This commit is contained in:
wls2002
2023-05-09 01:49:43 +08:00
parent ee6bb01eff
commit f63a0c447b
7 changed files with 231 additions and 18 deletions

View File

@@ -0,0 +1,215 @@
"""
Lowers, compiles, and creates functions used in the NEAT pipeline.
"""
from functools import partial
import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key
from .genome.genome import initialize_genomes
from .genome.mutate import mutate
from .genome.distance import distance
from .genome.crossover import crossover
class FunctionFactory:
def __init__(self, config, debug=False):
self.config = config
self.debug = debug
self.init_N = config.basic.init_maximum_nodes
self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {}
self.load_config_vals(config)
self.precompile()
pass
def load_config_vals(self, config):
self.pop_size = config.neat.population.pop_size
self.init_N = config.basic.init_maximum_nodes
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
self.num_inputs = config.basic.num_inputs
self.num_outputs = config.basic.num_outputs
self.input_idx = np.arange(self.num_inputs)
self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)
bias = config.neat.gene.bias
self.bias_mean = bias.init_mean
self.bias_std = bias.init_stdev
self.bias_mutate_strength = bias.mutate_power
self.bias_mutate_rate = bias.mutate_rate
self.bias_replace_rate = bias.replace_rate
response = config.neat.gene.response
self.response_mean = response.init_mean
self.response_std = response.init_stdev
self.response_mutate_strength = response.mutate_power
self.response_mutate_rate = response.mutate_rate
self.response_replace_rate = response.replace_rate
weight = config.neat.gene.weight
self.weight_mean = weight.init_mean
self.weight_std = weight.init_stdev
self.weight_mutate_strength = weight.mutate_power
self.weight_mutate_rate = weight.mutate_rate
self.weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
self.act_default = act_name2key[activation.default]
self.act_list = np.array([act_name2key[name] for name in activation.options])
self.act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
self.agg_default = agg_name2key[aggregation.default]
self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
self.agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
self.enabled_reverse_rate = enabled.mutate_rate
genome = config.neat.genome
self.add_node_rate = genome.node_add_prob
self.delete_node_rate = genome.node_delete_prob
self.add_connection_rate = genome.conn_add_prob
self.delete_connection_rate = genome.conn_delete_prob
self.single_structure_mutate = genome.single_structural_mutation
def create_initialize(self):
func = partial(
initialize_genomes,
pop_size=self.pop_size,
N=self.init_N,
num_inputs=self.num_inputs,
num_outputs=self.num_outputs,
default_bias=self.bias_mean,
default_response=self.response_mean,
default_act=self.act_default,
default_agg=self.agg_default,
default_weight=self.weight_mean
)
if self.debug:
return lambda *args: func(*args)
else:
return func
def precompile(self):
self.create_mutate_with_args()
self.create_distance_with_args()
self.create_crossover_with_args()
n = self.init_N
print("start precompile")
for _ in range(self.precompile_times):
self.compile_mutate(n)
self.compile_distance(n)
self.compile_crossover(n)
n = int(self.expand_coe * n)
print("end precompile")
def create_mutate_with_args(self):
func = partial(
mutate,
input_idx=self.input_idx,
output_idx=self.output_idx,
bias_mean=self.bias_mean,
bias_std=self.bias_std,
bias_mutate_strength=self.bias_mutate_strength,
bias_mutate_rate=self.bias_mutate_rate,
bias_replace_rate=self.bias_replace_rate,
response_mean=self.response_mean,
response_std=self.response_std,
response_mutate_strength=self.response_mutate_strength,
response_mutate_rate=self.response_mutate_rate,
response_replace_rate=self.response_replace_rate,
weight_mean=self.weight_mean,
weight_std=self.weight_std,
weight_mutate_strength=self.weight_mutate_strength,
weight_mutate_rate=self.weight_mutate_rate,
weight_replace_rate=self.weight_replace_rate,
act_default=self.act_default,
act_list=self.act_list,
act_replace_rate=self.act_replace_rate,
agg_default=self.agg_default,
agg_list=self.agg_list,
agg_replace_rate=self.agg_replace_rate,
enabled_reverse_rate=self.enabled_reverse_rate,
add_node_rate=self.add_node_rate,
delete_node_rate=self.delete_node_rate,
add_connection_rate=self.add_connection_rate,
delete_connection_rate=self.delete_connection_rate,
single_structure_mutate=self.single_structure_mutate
)
self.mutate_with_args = func
def compile_mutate(self, n):
func = self.mutate_with_args
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
self.compiled_function[('mutate', n)] = batched_mutate_func
def create_mutate(self, n):
key = ('mutate', n)
if key not in self.compiled_function:
self.compile_mutate(n)
return self.compiled_function[key]
def create_distance_with_args(self):
func = partial(
distance,
disjoint_coe=self.disjoint_coe,
compatibility_coe=self.compatibility_coe
)
self.distance_with_args = func
def compile_distance(self, n):
func = self.distance_with_args
o2o_nodes1_lower = np.zeros((n, 5))
o2o_connections1_lower = np.zeros((2, n, n))
o2o_nodes2_lower = np.zeros((n, 5))
o2o_connections2_lower = np.zeros((2, n, n))
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2o_nodes2_lower, o2o_connections2_lower).compile()
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n))
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2m_nodes2_lower,
o2m_connections2_lower).compile()
self.compiled_function[('o2o_distance', n)] = o2o_distance
self.compiled_function[('o2m_distance', n)] = o2m_distance
def create_distance(self, n):
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
if key1 not in self.compiled_function:
self.compile_distance(n)
return self.compiled_function[key1], self.compiled_function[key2]
def create_crossover_with_args(self):
self.crossover_with_args = crossover
def compile_crossover(self, n):
func = self.crossover_with_args
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes1_lower = np.zeros((self.pop_size, n, 5))
connections1_lower = np.zeros((self.pop_size, 2, n, n))
nodes2_lower = np.zeros((self.pop_size, n, 5))
connections2_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
self.compiled_function[('crossover', n)] = func
def create_crossover(self, n):
key = ('crossover', n)
if key not in self.compiled_function:
self.compile_crossover(n)
return self.compiled_function[key]

View File

@@ -3,3 +3,5 @@ from .distance import create_distance_function
from .mutate import create_mutate_function from .mutate import create_mutate_function
from .forward import create_forward_function from .forward import create_forward_function
from .crossover import create_crossover_function from .crossover import create_crossover_function
from .activations import act_name2key
from .aggregations import agg_name2key

View File

@@ -33,9 +33,6 @@ def create_distance_function(N, config, type: str, debug: bool = False):
else: else:
return res_func return res_func
# return lambda nodes1, connections1, nodes2, connections2: \
# distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m': elif type == 'o2m':
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
pop_size = config.neat.population.pop_size pop_size = config.neat.population.pop_size

View File

@@ -45,7 +45,8 @@ def create_initialize_function(config):
def initialize_genomes(pop_size: int, def initialize_genomes(pop_size: int,
N: int, N: int,
num_inputs: int, num_outputs: int, num_inputs: int,
num_outputs: int,
default_bias: float = 0.0, default_bias: float = 0.0,
default_response: float = 1.0, default_response: float = 1.0,
default_act: int = 0, default_act: int = 0,

View File

@@ -113,13 +113,11 @@ def mutate(rand_key: Array,
new_node_key: int, new_node_key: int,
input_idx: Array, input_idx: Array,
output_idx: Array, output_idx: Array,
bias_default: float = 0,
bias_mean: float = 0, bias_mean: float = 0,
bias_std: float = 1, bias_std: float = 1,
bias_mutate_strength: float = 0.5, bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7, bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1, bias_replace_rate: float = 0.1,
response_default: float = 1,
response_mean: float = 1., response_mean: float = 1.,
response_std: float = 0., response_std: float = 0.,
response_mutate_strength: float = 0., response_mutate_strength: float = 0.,
@@ -147,8 +145,6 @@ def mutate(rand_key: Array,
:param input_idx: :param input_idx:
:param agg_default: :param agg_default:
:param act_default: :param act_default:
:param response_default:
:param bias_default:
:param rand_key: :param rand_key:
:param nodes: (N, 5) :param nodes: (N, 5)
:param connections: (2, N, N) :param connections: (2, N, N)
@@ -186,7 +182,7 @@ def mutate(rand_key: Array,
return n, c return n, c
def m_add_node(rk, n, c): def m_add_node(rk, n, c):
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default) return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, agg_default)
def m_delete_node(rk, n, c): def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx) return mutate_delete_node(rk, n, c, input_idx, output_idx)

View File

@@ -2,13 +2,13 @@ from typing import List, Union, Tuple, Callable
import time import time
import jax import jax
import jax.numpy as jnp
import numpy as np import numpy as np
from .species import SpeciesController from .species import SpeciesController
from .genome import expand, expand_single from .genome import expand, expand_single
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \ from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
create_distance_function, create_crossover_function create_distance_function, create_crossover_function
from .function_factory import FunctionFactory
class Pipeline: class Pipeline:
@@ -17,17 +17,18 @@ class Pipeline:
""" """
def __init__(self, config, seed=42): def __init__(self, config, seed=42):
self.generation_timestamp = time.time()
self.randkey = jax.random.PRNGKey(seed) self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed) np.random.seed(seed)
self.config = config self.config = config
self.function_factory = FunctionFactory(config)
self.N = config.basic.init_maximum_nodes self.N = config.basic.init_maximum_nodes
self.expand_coe = config.basic.expands_coe self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size self.pop_size = config.neat.population.pop_size
self.species_controller = SpeciesController(config) self.species_controller = SpeciesController(config)
self.initialize_func = create_initialize_function(config) self.initialize_func = self.function_factory.create_initialize()
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.compile_functions(debug=True) self.compile_functions(debug=True)
@@ -36,6 +37,7 @@ class Pipeline:
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections) self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
self.best_fitness = float('-inf') self.best_fitness = float('-inf')
self.generation_timestamp = time.time()
def ask(self, batch: bool): def ask(self, batch: bool):
""" """
@@ -140,10 +142,9 @@ class Pipeline:
self.compile_functions(debug=True) self.compile_functions(debug=True)
def compile_functions(self, debug=False): def compile_functions(self, debug=False):
self.mutate_func = create_mutate_function(self.N, self.config, batch=True, debug=debug) self.mutate_func = self.function_factory.create_mutate(self.N)
self.crossover_func = create_crossover_function(self.N, self.config, batch=True, debug=debug) self.crossover_func = self.function_factory.create_crossover(self.N)
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o', debug=debug) self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N)
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m', debug=debug)
def default_analysis(self, fitnesses): def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)

View File

@@ -2,8 +2,9 @@
"basic": { "basic": {
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"init_maximum_nodes": 30, "init_maximum_nodes": 10,
"expands_coe": 2 "expands_coe": 2,
"pre_compile_times": 3
}, },
"neat": { "neat": {
"population": { "population": {