From 2b79f2c903f02bc2ca8ca91e3736faf64bec35d1 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 14 May 2023 15:27:17 +0800 Subject: [PATCH] prepare for experiment --- algorithms/neat/function_factory.py | 77 ++++++++++------------ algorithms/neat/genome/activations.py | 5 +- algorithms/neat/genome/mutate.py | 16 +++++ algorithms/neat/genome/utils.py | 6 ++ algorithms/neat/pipeline.py | 21 ++++-- examples/enhane_xor.py | 44 +++++++++++++ examples/final_design_experiement.py | 56 ++++++++++++++++ examples/xor.py | 18 +++-- problems/function_fitting/__init__.py | 3 +- problems/function_fitting/enhance_logic.py | 54 +++++++++++++++ utils/default_config.json | 14 ++-- 11 files changed, 252 insertions(+), 62 deletions(-) create mode 100644 examples/enhane_xor.py create mode 100644 examples/final_design_experiement.py create mode 100644 problems/function_fitting/enhance_logic.py diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index 5a8c257..8c05bb3 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -19,7 +19,7 @@ class FunctionFactory: self.expand_coe = config.basic.expands_coe self.precompile_times = config.basic.pre_compile_times self.compiled_function = {} - self.time_cost = {} + self.compile_time = 0 self.load_config_vals(config) @@ -150,6 +150,8 @@ class FunctionFactory: return self.compiled_function[key] def compile_update_speciate(self, N, C, S): + s = time.time() + func = self.update_speciate_with_args randkey_lower = np.zeros((2,), dtype=np.uint32) pop_nodes_lower = np.zeros((self.pop_size, N, 5)) @@ -177,16 +179,22 @@ class FunctionFactory: ).compile() self.compiled_function[("update_speciate", N, C, S)] = compiled_func + self.compile_time += time.time() - s + def create_topological_sort_with_args(self): self.topological_sort_with_args = topological_sort def compile_topological_sort(self, n): + s = time.time() + func = self.topological_sort_with_args nodes_lower = np.zeros((n, 5)) connections_lower = np.zeros((2, n, n)) func = jit(func).lower(nodes_lower, connections_lower).compile() self.compiled_function[('topological_sort', n)] = func + self.compile_time += time.time() - s + def create_topological_sort(self, n): key = ('topological_sort', n) if key not in self.compiled_function: @@ -194,6 +202,8 @@ class FunctionFactory: return self.compiled_function[key] def compile_topological_sort_batch(self, n): + s = time.time() + func = self.topological_sort_with_args func = vmap(func) nodes_lower = np.zeros((self.pop_size, n, 5)) @@ -201,6 +211,8 @@ class FunctionFactory: func = jit(func).lower(nodes_lower, connections_lower).compile() self.compiled_function[('topological_sort_batch', n)] = func + self.compile_time += time.time() - s + def create_topological_sort_batch(self, n): key = ('topological_sort_batch', n) if key not in self.compiled_function: @@ -215,32 +227,10 @@ class FunctionFactory: ) self.single_forward_with_args = func - def compile_single_forward(self, n): - """ - single input for a genome - :param n: - :return: - """ - func = self.single_forward_with_args - inputs_lower = np.zeros((self.num_inputs,)) - cal_seqs_lower = np.zeros((n,), dtype=np.int32) - nodes_lower = np.zeros((n, 5)) - connections_lower = np.zeros((2, n, n)) - func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() - self.compiled_function[('single_forward', n)] = func - - def compile_pop_forward(self, n): - func = self.single_forward_with_args - func = vmap(func, in_axes=(None, 0, 0, 0)) - - inputs_lower = np.zeros((self.num_inputs,)) - cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32) - nodes_lower = np.zeros((self.pop_size, n, 5)) - connections_lower = np.zeros((self.pop_size, 2, n, n)) - func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() - self.compiled_function[('pop_forward', n)] = func def compile_batch_forward(self, n): + s = time.time() + func = self.single_forward_with_args func = vmap(func, in_axes=(0, None, None, None)) @@ -251,19 +241,19 @@ class FunctionFactory: func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() self.compiled_function[('batch_forward', n)] = func + self.compile_time += time.time() - s + def create_batch_forward(self, n): key = ('batch_forward', n) if key not in self.compiled_function: self.compile_batch_forward(n) - if self.debug: - def debug_batch_forward(*args): - return self.compiled_function[key](*args).block_until_ready() - return debug_batch_forward - else: - return self.compiled_function[key] + return self.compiled_function[key] def compile_pop_batch_forward(self, n): + + s = time.time() + func = self.single_forward_with_args func = vmap(func, in_axes=(0, None, None, None)) # batch_forward func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward @@ -276,25 +266,24 @@ class FunctionFactory: func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() self.compiled_function[('pop_batch_forward', n)] = func + self.compile_time += time.time() - s + def create_pop_batch_forward(self, n): key = ('pop_batch_forward', n) if key not in self.compiled_function: self.compile_pop_batch_forward(n) - if self.debug: - def debug_pop_batch_forward(*args): - return self.compiled_function[key](*args).block_until_ready() - return debug_pop_batch_forward - else: - return self.compiled_function[key] + return self.compiled_function[key] def ask_pop_batch_forward(self, pop_nodes, pop_cons): n, c = pop_nodes.shape[1], pop_cons.shape[1] batch_unflatten_func = self.create_batch_unflatten_connections(n, c) pop_cons = batch_unflatten_func(pop_nodes, pop_cons) ts = self.create_topological_sort_batch(n) - pop_cal_seqs = ts(pop_nodes, pop_cons) + # for connections with enabled is false, set weight to 0) + pop_cal_seqs = ts(pop_nodes, pop_cons) + # print(pop_cal_seqs) forward_func = self.create_pop_batch_forward(n) def debug_forward(inputs): @@ -314,6 +303,9 @@ class FunctionFactory: return debug_forward def compile_batch_unflatten_connections(self, n, c): + + s = time.time() + func = unflatten_connections func = vmap(func) pop_nodes_lower = np.zeros((self.pop_size, n, 5)) @@ -321,14 +313,11 @@ class FunctionFactory: func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile() self.compiled_function[('batch_unflatten_connections', n, c)] = func + self.compile_time += time.time() - s + def create_batch_unflatten_connections(self, n, c): key = ('batch_unflatten_connections', n, c) if key not in self.compiled_function: self.compile_batch_unflatten_connections(n, c) - if self.debug: - def debug_batch_unflatten_connections(*args): - return self.compiled_function[key](*args).block_until_ready() - return debug_batch_unflatten_connections - else: - return self.compiled_function[key] + return self.compiled_function[key] diff --git a/algorithms/neat/genome/activations.py b/algorithms/neat/genome/activations.py index db30a78..eaf048b 100644 --- a/algorithms/neat/genome/activations.py +++ b/algorithms/neat/genome/activations.py @@ -133,5 +133,8 @@ act_name2key = { def act(idx, z): idx = jnp.asarray(idx, dtype=jnp.int32) # change idx from float to int - return jax.lax.switch(idx, ACT_TOTAL_LIST, z) + res = jax.lax.switch(idx, ACT_TOTAL_LIST, z) + return jnp.where(jnp.isnan(res), jnp.nan, res) + + # return jax.lax.switch(idx, ACT_TOTAL_LIST, z) diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index 88c56ce..0638377 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -88,6 +88,12 @@ def mutate(rand_key: Array, def m_add_connection(rk, n, c): return mutate_add_connection(rk, n, c, input_idx, output_idx) + def m_delete_node(rk, n, c): + return mutate_delete_node(rk, n, c, input_idx, output_idx) + + def m_delete_connection(rk, n, c): + return mutate_delete_connection(rk, n, c) + r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) # mutate add node @@ -100,6 +106,16 @@ def mutate(rand_key: Array, nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes) connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections) + # mutate delete node + aux_nodes, aux_connections = m_delete_node(r2, nodes, connections) + nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes) + connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections) + + # mutate delete connection + aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections) + nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes) + connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections) + nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, bias_replace_rate, response_mean, response_std, response_mutate_strength, response_mutate_rate, response_replace_rate, diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py index 8e662e7..9dd741d 100644 --- a/algorithms/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -14,6 +14,8 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan) def unflatten_connections(nodes, cons): """ transform the (C, 4) connections to (2, N, N) + this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled + connections to nan, that means we dont consider such connection when forward; :param cons: :param nodes: :return: @@ -29,6 +31,10 @@ def unflatten_connections(nodes, cons): # however, it will do nothing set values in an array res = res.at[0, i_idxs, o_idxs].set(cons[:, 2]) res = res.at[1, i_idxs, o_idxs].set(cons[:, 3]) + + # (2, N, N), (2, N, N), (2, N, N) + # res = jnp.where(res[1, :, :] == 0, jnp.nan, res) + return res diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 27c0083..c3326d3 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -16,9 +16,9 @@ class Pipeline: Neat algorithm pipeline. """ - def __init__(self, config, seed=42): + def __init__(self, config, function_factory, seed=42): self.time_dict = {} - self.function_factory = FunctionFactory(config) + self.function_factory = function_factory self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) @@ -31,18 +31,21 @@ class Pipeline: self.pop_size = config.neat.population.pop_size self.species_controller = SpeciesController(config) - self.initialize_func = self.function_factory.create_initialize() + self.initialize_func = self.function_factory.create_initialize(self.N, self.C) self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func() self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S) self.generation = 0 + self.generation_time_list = [] self.species_controller.init_speciate(self.pop_nodes, self.pop_cons) self.best_fitness = float('-inf') self.best_genome = None self.generation_timestamp = time.time() + self.evaluate_time = 0 + def ask(self): """ Create a forward function for the population. @@ -66,7 +69,9 @@ class Pipeline: new_node_keys, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start) - idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys]) + + self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \ + jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys]) self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation) @@ -75,7 +80,12 @@ class Pipeline: def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config.neat.population.generation_limit): forward_func = self.ask() + + tic = time.time() fitnesses = fitness_func(forward_func) + self.evaluate_time += time.time() - tic + + assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" if analysis is not None: if analysis == "default": @@ -104,6 +114,7 @@ class Pipeline: max_node_size = np.max(pop_node_sizes) if max_node_size >= self.N: self.N = int(self.N * self.expand_coe) + # self.C = int(self.C * self.expand_coe) print(f"node expand to {self.N}!") self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C) @@ -116,6 +127,7 @@ class Pipeline: pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) max_con_size = np.max(pop_node_sizes) if max_con_size >= self.C: + # self.N = int(self.N * self.expand_coe) self.C = int(self.C * self.expand_coe) print(f"connections expand to {self.C}!") self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C) @@ -134,6 +146,7 @@ class Pipeline: new_timestamp = time.time() cost_time = new_timestamp - self.generation_timestamp + self.generation_time_list.append(cost_time) self.generation_timestamp = new_timestamp max_idx = np.argmax(fitnesses) diff --git a/examples/enhane_xor.py b/examples/enhane_xor.py new file mode 100644 index 0000000..cdd205c --- /dev/null +++ b/examples/enhane_xor.py @@ -0,0 +1,44 @@ +import numpy as np +import jax +from utils import Configer +from algorithms.neat import Pipeline +from time_utils import using_cprofile +from algorithms.neat.function_factory import FunctionFactory +from problems import EnhanceLogic +import time + + +def evaluate(problem, func): + inputs = problem.ask_for_inputs() + pop_predict = jax.device_get(func(inputs)) + # print(pop_predict) + fitnesses = [] + for predict in pop_predict: + f = problem.evaluate_predict(predict) + fitnesses.append(f) + return np.array(fitnesses) + +# @using_cprofile +# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/") +def main(): + tic = time.time() + config = Configer.load_config() + problem = EnhanceLogic("xor", n=3) + problem.refactor_config(config) + function_factory = FunctionFactory(config) + evaluate_func = lambda func: evaluate(problem, func) + pipeline = Pipeline(config, function_factory, seed=33413) + print("start run") + pipeline.auto_run(evaluate_func) + + total_time = time.time() - tic + compile_time = pipeline.function_factory.compile_time + total_it = pipeline.generation + mean_time_per_it = (total_time - compile_time) / total_it + evaluate_time = pipeline.evaluate_time + print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s") + print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s") + + +if __name__ == '__main__': + main() diff --git a/examples/final_design_experiement.py b/examples/final_design_experiement.py new file mode 100644 index 0000000..1c29b91 --- /dev/null +++ b/examples/final_design_experiement.py @@ -0,0 +1,56 @@ +import numpy as np +import jax +from utils import Configer +from algorithms.neat import Pipeline +from time_utils import using_cprofile +from algorithms.neat.function_factory import FunctionFactory +from problems import EnhanceLogic +import time + + +def evaluate(problem, func): + outs = func(problem.inputs) + outs = jax.device_get(outs) + fitnesses = -np.mean((problem.outputs - outs) ** 2, axis=(1, 2)) + return fitnesses + + +def main(): + config = Configer.load_config() + problem = EnhanceLogic("xor", n=3) + problem.refactor_config(config) + function_factory = FunctionFactory(config) + evaluate_func = lambda func: evaluate(problem, func) + + # precompile + pipeline = Pipeline(config, function_factory, seed=114514) + pipeline.auto_run(evaluate_func) + + for r in range(10): + print(f"running: {r}/{10}") + tic = time.time() + + pipeline = Pipeline(config, function_factory, seed=r) + pipeline.auto_run(evaluate_func) + + total_time = time.time() - tic + evaluate_time = pipeline.evaluate_time + total_it = pipeline.generation + print(f"total time: {total_time:.2f}s, evaluate time: {evaluate_time:.2f}s, total_it: {total_it}") + + if total_it >= 500: + res = "fail" + else: + res = "success" + + with open("log", "wb") as f: + f.write(f"{res}, total time: {total_time:.2f}s, evaluate time: {evaluate_time:.2f}s, total_it: {total_it}\n".encode("utf-8")) + f.write(str(pipeline.generation_time_list).encode("utf-8")) + + compile_time = function_factory.compile_time + + print("total_compile_time:", compile_time) + + +if __name__ == '__main__': + main() diff --git a/examples/xor.py b/examples/xor.py index 61d7398..b61f5ba 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,19 +1,27 @@ -from functools import partial - from utils import Configer from algorithms.neat import Pipeline from time_utils import using_cprofile from problems import Sin, Xor, DIY +import time -@using_cprofile +# @using_cprofile # @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/") def main(): + tic = time.time() config = Configer.load_config() problem = Xor() problem.refactor_config(config) - pipeline = Pipeline(config, seed=1) - pipeline.auto_run(problem.evaluate) + pipeline = Pipeline(config, seed=6) + nodes, cons = pipeline.auto_run(problem.evaluate) + # print(nodes, cons) + total_time = time.time() - tic + compile_time = pipeline.function_factory.compile_time + total_it = pipeline.generation + mean_time_per_it = (total_time - compile_time) / total_it + evaluate_time = pipeline.evaluate_time + print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s") + print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s") if __name__ == '__main__': diff --git a/problems/function_fitting/__init__.py b/problems/function_fitting/__init__.py index 8461221..85b1f28 100644 --- a/problems/function_fitting/__init__.py +++ b/problems/function_fitting/__init__.py @@ -1,4 +1,5 @@ from .function_fitting_problem import FunctionFittingProblem from .xor import * from .sin import * -from .diy import * \ No newline at end of file +from .diy import * +from .enhance_logic import * \ No newline at end of file diff --git a/problems/function_fitting/enhance_logic.py b/problems/function_fitting/enhance_logic.py new file mode 100644 index 0000000..c144c72 --- /dev/null +++ b/problems/function_fitting/enhance_logic.py @@ -0,0 +1,54 @@ +""" +xor problem in multiple dimensions +""" + +from itertools import product +import numpy as np + + +class EnhanceLogic: + def __init__(self, name="xor", n=2): + self.name = name + self.n = n + self.num_inputs = n + self.num_outputs = 1 + self.batch = 2 ** n + self.forward_way = 'pop_batch' + + self.inputs = np.array(generate_permutations(n), dtype=np.float32) + + if self.name == "xor": + self.outputs = np.sum(self.inputs, axis=1) % 2 + elif self.name == "and": + self.outputs = np.all(self.inputs==1, axis=1) + elif self.name == "or": + self.outputs = np.any(self.inputs==1, axis=1) + else: + raise NotImplementedError("Only support xor, and, or") + self.outputs = self.outputs[:, np.newaxis] + + + def refactor_config(self, config): + config.basic.forward_way = self.forward_way + config.basic.num_inputs = self.num_inputs + config.basic.num_outputs = self.num_outputs + config.basic.problem_batch = self.batch + + + def ask_for_inputs(self): + return self.inputs + + def evaluate_predict(self, predict): + # print((predict - self.outputs) ** 2) + return -np.mean((predict - self.outputs) ** 2) + + + +def generate_permutations(n): + permutations = [list(i) for i in product([0, 1], repeat=n)] + + return permutations + + +if __name__ == '__main__': + _ = EnhanceLogic(4) diff --git a/utils/default_config.json b/utils/default_config.json index 27aaa70..ab1c691 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -13,9 +13,9 @@ "neat": { "population": { "fitness_criterion": "max", - "fitness_threshold": -0.001, - "generation_limit": 1000, - "pop_size": 1000, + "fitness_threshold": -1e-2, + "generation_limit": 500, + "pop_size": 5000, "reset_on_extinction": "False" }, "gene": { @@ -35,7 +35,7 @@ }, "activation": { "default": "sigmoid", - "options": "sigmoid", + "options": ["sigmoid"], "mutate_rate": 0.1 }, "aggregation": { @@ -58,13 +58,13 @@ "compatibility_disjoint_coefficient": 1.0, "compatibility_weight_coefficient": 0.5, "single_structural_mutation": "False", - "conn_add_prob": 0.5, + "conn_add_prob": 0.6, "conn_delete_prob": 0, - "node_add_prob": 0.2, + "node_add_prob": 0.3, "node_delete_prob": 0 }, "species": { - "compatibility_threshold": 3, + "compatibility_threshold": 2.5, "species_fitness_func": "max", "max_stagnation": 20, "species_elitism": 2,