From f63a0c447b0cf87a72a5190365a890d744f84bed Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 9 May 2023 01:49:43 +0800 Subject: [PATCH] create function_factory.py, use to manage functions --- algorithms/neat/function_factory.py | 215 ++++++++++++++++++++++++++++ algorithms/neat/genome/__init__.py | 2 + algorithms/neat/genome/distance.py | 3 - algorithms/neat/genome/genome.py | 3 +- algorithms/neat/genome/mutate.py | 6 +- algorithms/neat/pipeline.py | 15 +- utils/default_config.json | 5 +- 7 files changed, 231 insertions(+), 18 deletions(-) create mode 100644 algorithms/neat/function_factory.py diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py new file mode 100644 index 0000000..20bf671 --- /dev/null +++ b/algorithms/neat/function_factory.py @@ -0,0 +1,215 @@ +""" +Lowers, compiles, and creates functions used in the NEAT pipeline. +""" +from functools import partial + +import numpy as np +from jax import jit, vmap + +from .genome import act_name2key, agg_name2key +from .genome.genome import initialize_genomes +from .genome.mutate import mutate +from .genome.distance import distance +from .genome.crossover import crossover + + +class FunctionFactory: + def __init__(self, config, debug=False): + self.config = config + self.debug = debug + + self.init_N = config.basic.init_maximum_nodes + self.expand_coe = config.basic.expands_coe + self.precompile_times = config.basic.pre_compile_times + self.compiled_function = {} + + self.load_config_vals(config) + self.precompile() + pass + + def load_config_vals(self, config): + self.pop_size = config.neat.population.pop_size + self.init_N = config.basic.init_maximum_nodes + + self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient + self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient + + self.num_inputs = config.basic.num_inputs + self.num_outputs = config.basic.num_outputs + self.input_idx = np.arange(self.num_inputs) + self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs) + + bias = config.neat.gene.bias + self.bias_mean = bias.init_mean + self.bias_std = bias.init_stdev + self.bias_mutate_strength = bias.mutate_power + self.bias_mutate_rate = bias.mutate_rate + self.bias_replace_rate = bias.replace_rate + + response = config.neat.gene.response + self.response_mean = response.init_mean + self.response_std = response.init_stdev + self.response_mutate_strength = response.mutate_power + self.response_mutate_rate = response.mutate_rate + self.response_replace_rate = response.replace_rate + + weight = config.neat.gene.weight + self.weight_mean = weight.init_mean + self.weight_std = weight.init_stdev + self.weight_mutate_strength = weight.mutate_power + self.weight_mutate_rate = weight.mutate_rate + self.weight_replace_rate = weight.replace_rate + + activation = config.neat.gene.activation + self.act_default = act_name2key[activation.default] + self.act_list = np.array([act_name2key[name] for name in activation.options]) + self.act_replace_rate = activation.mutate_rate + + aggregation = config.neat.gene.aggregation + self.agg_default = agg_name2key[aggregation.default] + self.agg_list = np.array([agg_name2key[name] for name in aggregation.options]) + self.agg_replace_rate = aggregation.mutate_rate + + enabled = config.neat.gene.enabled + self.enabled_reverse_rate = enabled.mutate_rate + + genome = config.neat.genome + self.add_node_rate = genome.node_add_prob + self.delete_node_rate = genome.node_delete_prob + self.add_connection_rate = genome.conn_add_prob + self.delete_connection_rate = genome.conn_delete_prob + self.single_structure_mutate = genome.single_structural_mutation + + def create_initialize(self): + func = partial( + initialize_genomes, + pop_size=self.pop_size, + N=self.init_N, + num_inputs=self.num_inputs, + num_outputs=self.num_outputs, + default_bias=self.bias_mean, + default_response=self.response_mean, + default_act=self.act_default, + default_agg=self.agg_default, + default_weight=self.weight_mean + ) + if self.debug: + return lambda *args: func(*args) + else: + return func + + def precompile(self): + self.create_mutate_with_args() + self.create_distance_with_args() + self.create_crossover_with_args() + n = self.init_N + print("start precompile") + for _ in range(self.precompile_times): + self.compile_mutate(n) + self.compile_distance(n) + self.compile_crossover(n) + n = int(self.expand_coe * n) + print("end precompile") + + def create_mutate_with_args(self): + func = partial( + mutate, + input_idx=self.input_idx, + output_idx=self.output_idx, + bias_mean=self.bias_mean, + bias_std=self.bias_std, + bias_mutate_strength=self.bias_mutate_strength, + bias_mutate_rate=self.bias_mutate_rate, + bias_replace_rate=self.bias_replace_rate, + response_mean=self.response_mean, + response_std=self.response_std, + response_mutate_strength=self.response_mutate_strength, + response_mutate_rate=self.response_mutate_rate, + response_replace_rate=self.response_replace_rate, + weight_mean=self.weight_mean, + weight_std=self.weight_std, + weight_mutate_strength=self.weight_mutate_strength, + weight_mutate_rate=self.weight_mutate_rate, + weight_replace_rate=self.weight_replace_rate, + act_default=self.act_default, + act_list=self.act_list, + act_replace_rate=self.act_replace_rate, + agg_default=self.agg_default, + agg_list=self.agg_list, + agg_replace_rate=self.agg_replace_rate, + enabled_reverse_rate=self.enabled_reverse_rate, + add_node_rate=self.add_node_rate, + delete_node_rate=self.delete_node_rate, + add_connection_rate=self.add_connection_rate, + delete_connection_rate=self.delete_connection_rate, + single_structure_mutate=self.single_structure_mutate + ) + self.mutate_with_args = func + + def compile_mutate(self, n): + func = self.mutate_with_args + rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) + nodes_lower = np.zeros((self.pop_size, n, 5)) + connections_lower = np.zeros((self.pop_size, 2, n, n)) + new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32) + batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower, + connections_lower, new_node_key_lower).compile() + self.compiled_function[('mutate', n)] = batched_mutate_func + + def create_mutate(self, n): + key = ('mutate', n) + if key not in self.compiled_function: + self.compile_mutate(n) + return self.compiled_function[key] + + def create_distance_with_args(self): + func = partial( + distance, + disjoint_coe=self.disjoint_coe, + compatibility_coe=self.compatibility_coe + ) + self.distance_with_args = func + + def compile_distance(self, n): + func = self.distance_with_args + o2o_nodes1_lower = np.zeros((n, 5)) + o2o_connections1_lower = np.zeros((2, n, n)) + o2o_nodes2_lower = np.zeros((n, 5)) + o2o_connections2_lower = np.zeros((2, n, n)) + o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower, + o2o_nodes2_lower, o2o_connections2_lower).compile() + + o2m_nodes2_lower = np.zeros((self.pop_size, n, 5)) + o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n)) + o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower, + o2m_nodes2_lower, + o2m_connections2_lower).compile() + + self.compiled_function[('o2o_distance', n)] = o2o_distance + self.compiled_function[('o2m_distance', n)] = o2m_distance + + def create_distance(self, n): + key1, key2 = ('o2o_distance', n), ('o2m_distance', n) + if key1 not in self.compiled_function: + self.compile_distance(n) + return self.compiled_function[key1], self.compiled_function[key2] + + def create_crossover_with_args(self): + self.crossover_with_args = crossover + + def compile_crossover(self, n): + func = self.crossover_with_args + randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) + nodes1_lower = np.zeros((self.pop_size, n, 5)) + connections1_lower = np.zeros((self.pop_size, 2, n, n)) + nodes2_lower = np.zeros((self.pop_size, n, 5)) + connections2_lower = np.zeros((self.pop_size, 2, n, n)) + func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower, + nodes2_lower, connections2_lower).compile() + self.compiled_function[('crossover', n)] = func + + def create_crossover(self, n): + key = ('crossover', n) + if key not in self.compiled_function: + self.compile_crossover(n) + return self.compiled_function[key] diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py index a588d97..fdd306b 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -3,3 +3,5 @@ from .distance import create_distance_function from .mutate import create_mutate_function from .forward import create_forward_function from .crossover import create_crossover_function +from .activations import act_name2key +from .aggregations import agg_name2key diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 58c24df..4a78b3f 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -33,9 +33,6 @@ def create_distance_function(N, config, type: str, debug: bool = False): else: return res_func - # return lambda nodes1, connections1, nodes2, connections2: \ - # distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) - elif type == 'o2m': vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) pop_size = config.neat.population.pop_size diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 53dfcf1..591f022 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -45,7 +45,8 @@ def create_initialize_function(config): def initialize_genomes(pop_size: int, N: int, - num_inputs: int, num_outputs: int, + num_inputs: int, + num_outputs: int, default_bias: float = 0.0, default_response: float = 1.0, default_act: int = 0, diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index 79d462a..af08637 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -113,13 +113,11 @@ def mutate(rand_key: Array, new_node_key: int, input_idx: Array, output_idx: Array, - bias_default: float = 0, bias_mean: float = 0, bias_std: float = 1, bias_mutate_strength: float = 0.5, bias_mutate_rate: float = 0.7, bias_replace_rate: float = 0.1, - response_default: float = 1, response_mean: float = 1., response_std: float = 0., response_mutate_strength: float = 0., @@ -147,8 +145,6 @@ def mutate(rand_key: Array, :param input_idx: :param agg_default: :param act_default: - :param response_default: - :param bias_default: :param rand_key: :param nodes: (N, 5) :param connections: (2, N, N) @@ -186,7 +182,7 @@ def mutate(rand_key: Array, return n, c def m_add_node(rk, n, c): - return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default) + return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, agg_default) def m_delete_node(rk, n, c): return mutate_delete_node(rk, n, c, input_idx, output_idx) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 8196470..3f2c2bc 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -2,13 +2,13 @@ from typing import List, Union, Tuple, Callable import time import jax -import jax.numpy as jnp import numpy as np from .species import SpeciesController from .genome import expand, expand_single from .genome import create_initialize_function, create_mutate_function, create_forward_function, \ create_distance_function, create_crossover_function +from .function_factory import FunctionFactory class Pipeline: @@ -17,17 +17,18 @@ class Pipeline: """ def __init__(self, config, seed=42): - self.generation_timestamp = time.time() self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) self.config = config + self.function_factory = FunctionFactory(config) + self.N = config.basic.init_maximum_nodes self.expand_coe = config.basic.expands_coe self.pop_size = config.neat.population.pop_size self.species_controller = SpeciesController(config) - self.initialize_func = create_initialize_function(config) + self.initialize_func = self.function_factory.create_initialize() self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() self.compile_functions(debug=True) @@ -36,6 +37,7 @@ class Pipeline: self.species_controller.init_speciate(self.pop_nodes, self.pop_connections) self.best_fitness = float('-inf') + self.generation_timestamp = time.time() def ask(self, batch: bool): """ @@ -140,10 +142,9 @@ class Pipeline: self.compile_functions(debug=True) def compile_functions(self, debug=False): - self.mutate_func = create_mutate_function(self.N, self.config, batch=True, debug=debug) - self.crossover_func = create_crossover_function(self.N, self.config, batch=True, debug=debug) - self.o2o_distance = create_distance_function(self.N, self.config, type='o2o', debug=debug) - self.o2m_distance = create_distance_function(self.N, self.config, type='o2m', debug=debug) + self.mutate_func = self.function_factory.create_mutate(self.N) + self.crossover_func = self.function_factory.create_crossover(self.N) + self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N) def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) diff --git a/utils/default_config.json b/utils/default_config.json index f6e8506..0bcdb21 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -2,8 +2,9 @@ "basic": { "num_inputs": 2, "num_outputs": 1, - "init_maximum_nodes": 30, - "expands_coe": 2 + "init_maximum_nodes": 10, + "expands_coe": 2, + "pre_compile_times": 3 }, "neat": { "population": {