From bd421de9ad1a84fafff01df791f4673360c8dc65 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 14 Jun 2023 10:20:55 +0800 Subject: [PATCH] Current Progress: After final design presentation --- examples/enhane_xor.py | 4 +- examples/final_design_experiement.py | 2 +- examples/final_design_experiment2.py | 52 +++++++++++++++++++ examples/jax_playground.py | 73 +++++++++++++++------------ examples/jitable_speciate_t.py | 75 ---------------------------- utils/default_config.json | 6 +-- 6 files changed, 99 insertions(+), 113 deletions(-) create mode 100644 examples/final_design_experiment2.py delete mode 100644 examples/jitable_speciate_t.py diff --git a/examples/enhane_xor.py b/examples/enhane_xor.py index cdd205c..3e37296 100644 --- a/examples/enhane_xor.py +++ b/examples/enhane_xor.py @@ -18,6 +18,7 @@ def evaluate(problem, func): fitnesses.append(f) return np.array(fitnesses) + # @using_cprofile # @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/") def main(): @@ -36,7 +37,8 @@ def main(): total_it = pipeline.generation mean_time_per_it = (total_time - compile_time) / total_it evaluate_time = pipeline.evaluate_time - print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s") + print( + f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s") print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s") diff --git a/examples/final_design_experiement.py b/examples/final_design_experiement.py index 1c29b91..d6a7c91 100644 --- a/examples/final_design_experiement.py +++ b/examples/final_design_experiement.py @@ -43,7 +43,7 @@ def main(): else: res = "success" - with open("log", "wb") as f: + with open("log", "ab") as f: f.write(f"{res}, total time: {total_time:.2f}s, evaluate time: {evaluate_time:.2f}s, total_it: {total_it}\n".encode("utf-8")) f.write(str(pipeline.generation_time_list).encode("utf-8")) diff --git a/examples/final_design_experiment2.py b/examples/final_design_experiment2.py new file mode 100644 index 0000000..84bf770 --- /dev/null +++ b/examples/final_design_experiment2.py @@ -0,0 +1,52 @@ + +import numpy as np +import jax +from utils import Configer +from algorithms.neat import Pipeline +from time_utils import using_cprofile +from algorithms.neat.function_factory import FunctionFactory +from problems import EnhanceLogic +import time + + +def evaluate(problem, func): + outs = func(problem.inputs) + outs = jax.device_get(outs) + fitnesses = -np.mean((problem.outputs - outs) ** 2, axis=(1, 2)) + return fitnesses + + +def main(): + config = Configer.load_config() + problem = EnhanceLogic("xor", n=3) + problem.refactor_config(config) + + evaluate_func = lambda func: evaluate(problem, func) + + for p in [100, 200, 500, 1000, 2000, 5000, 10000, 20000]: + config.neat.population.pop_size = p + tic = time.time() + function_factory = FunctionFactory(config) + print(f"running: {p}") + + pipeline = Pipeline(config, function_factory, seed=2) + pipeline.auto_run(evaluate_func) + + total_time = time.time() - tic + evaluate_time = pipeline.evaluate_time + total_it = pipeline.generation + print(f"total time: {total_time:.2f}s, evaluate time: {evaluate_time:.2f}s, total_it: {total_it}") + + with open("2060_log2", "ab") as f: + f.write \ + (f"{p}, total time: {total_time:.2f}s, compile time: {function_factory.compile_time:.2f}s, total_it: {total_it}\n".encode + ("utf-8")) + f.write(f"{str(pipeline.generation_time_list)}\n".encode("utf-8")) + + compile_time = function_factory.compile_time + + print("total_compile_time:", compile_time) + + +if __name__ == '__main__': + main() diff --git a/examples/jax_playground.py b/examples/jax_playground.py index e3e17b9..32ce960 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,50 +1,57 @@ import jax import jax.numpy as jnp from jax import jit, vmap -from time_utils import using_cprofile from time import time -# import numpy as np + + @jit -def fx(x): - return jnp.arange(x, x + 10) -# -# -# # @jit -# def fy(z): -# z1, z2 = z, z + 1 -# vmap_fx = vmap(fx) -# return vmap_fx(z1, z2) -# -# @jit -# def test_while(num, init_val): -# def cond_fun(carry): -# i, cumsum = carry -# return i < num -# -# def body_fun(carry): -# i, cumsum = carry -# cumsum += i -# return i + 1, cumsum -# -# return jax.lax.while_loop(cond_fun, body_fun, (0, init_val)) +def jax_mutate(seed, x): + noise = jax.random.normal(seed, x.shape) * 0.1 + return x + noise +def numpy_mutate(x): + noise = np.random.normal(size=x.shape) * 0.1 + return x + noise -# @using_cprofile +def jax_mutate_population(seed, pop_x): + seeds = jax.random.split(seed, len(pop_x)) + func = vmap(jax_mutate, in_axes=(0, 0)) + return func(seeds, pop_x) + + +def numpy_mutate_population(pop_x): + return np.stack([numpy_mutate(x) for x in pop_x]) + +def numpy_mutate_population_vmap(pop_x): + noise = np.random.normal(size=pop_x.shape) * 0.1 + return pop_x + noise + def main(): - print(fx(1)) + seed = jax.random.PRNGKey(0) + i = 10 + while i < 200000: + pop_x = jnp.ones((i, 100, 100)) + jax_pop_func = jit(jax_mutate_population).lower(seed, pop_x).compile() - # vmap_f = vmap(fx, in_axes=(None, 0)) - # vmap_vmap_f = vmap(vmap_f, in_axes=(0, None)) - # a = jnp.array([20,10,30]) - # b = jnp.array([6, 5, 4]) - # res = vmap_vmap_f(a, b) - # print(res) - # print(jnp.argmin(res, axis=1)) + tic = time() + res = jax.device_get(jax_pop_func(seed, pop_x)) + jax_time = time() - tic + tic = time() + res = numpy_mutate_population(pop_x) + numpy_time = time() - tic + tic = time() + res = numpy_mutate_population_vmap(pop_x) + numpy_time_vmap = time() - tic + + # print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Speedup: {numpy_time / jax_time:.4f}') + print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Numpy Vmap: {numpy_time_vmap:.4f}') + + i = int(i * 1.3) if __name__ == '__main__': main() diff --git a/examples/jitable_speciate_t.py b/examples/jitable_speciate_t.py deleted file mode 100644 index 9eae841..0000000 --- a/examples/jitable_speciate_t.py +++ /dev/null @@ -1,75 +0,0 @@ -import jax -import jax.numpy as jnp -import numpy as np -from algorithms.neat.function_factory import FunctionFactory -from algorithms.neat.genome.debug.tools import check_array_valid -from utils import Configer -from algorithms.neat.population import speciate -from algorithms.neat.genome.crossover import crossover -from algorithms.neat.genome.utils import I_INT -from time import time - -if __name__ == '__main__': - config = Configer.load_config() - function_factory = FunctionFactory(config, debug=True) - initialize_func = function_factory.create_initialize() - - pop_nodes, pop_connections, input_idx, output_idx = initialize_func() - mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1]) - crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1]) - - N, C, species_size = function_factory.init_N, function_factory.init_C, 20 - spe_center_nodes = np.full((species_size, N, 5), np.nan) - spe_center_connections = np.full((species_size, C, 4), np.nan) - spe_center_nodes[0] = pop_nodes[0] - spe_center_connections[0] = pop_connections[0] - spe_keys = np.full((species_size,), I_INT) - spe_keys[0] = 0 - new_spe_key = 1 - key = jax.random.PRNGKey(0) - new_node_idx = 100 - - while True: - start_time = time() - key, subkey = jax.random.split(key) - mutate_keys = jax.random.split(subkey, len(pop_nodes)) - new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes)) - new_node_idx += len(pop_nodes) - pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes) - pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections]) - # for i in range(len(pop_nodes)): - # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) - idx1 = np.random.permutation(len(pop_nodes)) - idx2 = np.random.permutation(len(pop_nodes)) - - n1, c1 = pop_nodes[idx1], pop_connections[idx1] - n2, c2 = pop_nodes[idx2], pop_connections[idx2] - crossover_keys = jax.random.split(subkey, len(pop_nodes)) - pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2) - - - # for i in range(len(pop_nodes)): - # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) - - #speciate next generation - - idx2specie, spe_center_nodes, spe_center_cons, spe_keys, new_spe_key = speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections, - spe_keys, new_spe_key, - compatibility_threshold=3) - - print(spe_keys, new_spe_key) - - # - # idx2specie = np.array(idx2specie) - # spe_dict = {} - # for i in range(len(idx2specie)): - # spe_idx = idx2specie[i] - # if spe_idx not in spe_dict: - # spe_dict[spe_idx] = 1 - # else: - # spe_dict[spe_idx] += 1 - # - # print(spe_dict) - # assert np.all(idx2specie != I_INT) - print(time() - start_time) - # print(idx2specie) diff --git a/utils/default_config.json b/utils/default_config.json index ab1c691..0071d57 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -13,9 +13,9 @@ "neat": { "population": { "fitness_criterion": "max", - "fitness_threshold": -1e-2, - "generation_limit": 500, - "pop_size": 5000, + "fitness_threshold": 1e-2, + "generation_limit": 100, + "pop_size": 1000, "reset_on_extinction": "False" }, "gene": {