From b271a56827f0f60c2a62b10beb47e55e5a3bf22c Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 10 May 2023 22:33:51 +0800 Subject: [PATCH] update to test in servers --- algorithms/neat/function_factory.py | 53 ++++++++++++++----- algorithms/neat/genome/activations.py | 1 + algorithms/neat/pipeline.py | 17 ++++-- examples/xor.py | 15 ++++-- problems/function_fitting/__init__.py | 1 + problems/function_fitting/diy.py | 14 +++++ .../function_fitting_problem.py | 25 +++++++-- problems/function_fitting/sin.py | 2 +- utils/default_config.json | 18 +++---- 9 files changed, 112 insertions(+), 34 deletions(-) create mode 100644 problems/function_fitting/diy.py diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index d024d9c..70c0b71 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -114,7 +114,7 @@ class FunctionFactory: self.compile_mutate(n) self.compile_distance(n) self.compile_crossover(n) - self.compile_topological_sort(n) + self.compile_topological_sort_batch(n) self.compile_pop_batch_forward(n) n = int(self.expand_coe * n) @@ -259,9 +259,8 @@ class FunctionFactory: def compile_topological_sort(self, n): func = self.topological_sort_with_args - func = vmap(func) - nodes_lower = np.zeros((self.pop_size, n, 5)) - connections_lower = np.zeros((self.pop_size, 2, n, n)) + nodes_lower = np.zeros((n, 5)) + connections_lower = np.zeros((2, n, n)) func = jit(func).lower(nodes_lower, connections_lower).compile() self.compiled_function[('topological_sort', n)] = func @@ -271,6 +270,20 @@ class FunctionFactory: self.compile_topological_sort(n) return self.compiled_function[key] + def compile_topological_sort_batch(self, n): + func = self.topological_sort_with_args + func = vmap(func) + nodes_lower = np.zeros((self.pop_size, n, 5)) + connections_lower = np.zeros((self.pop_size, 2, n, n)) + func = jit(func).lower(nodes_lower, connections_lower).compile() + self.compiled_function[('topological_sort_batch', n)] = func + + def create_topological_sort_batch(self, n): + key = ('topological_sort_batch', n) + if key not in self.compiled_function: + self.compile_topological_sort_batch(n) + return self.compiled_function[key] + def create_single_forward_with_args(self): func = partial( forward_single, @@ -315,6 +328,18 @@ class FunctionFactory: func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() self.compiled_function[('batch_forward', n)] = func + def create_batch_forward(self, n): + key = ('batch_forward', n) + if key not in self.compiled_function: + self.compile_batch_forward(n) + if self.debug: + def debug_batch_forward(*args): + return self.compiled_function[key](*args).block_until_ready() + + return debug_batch_forward + else: + return self.compiled_function[key] + def compile_pop_batch_forward(self, n): func = self.single_forward_with_args func = vmap(func, in_axes=(0, None, None, None)) # batch_forward @@ -340,9 +365,9 @@ class FunctionFactory: else: return self.compiled_function[key] - def ask(self, pop_nodes, pop_connections): + def ask_pop_batch_forward(self, pop_nodes, pop_connections): n = pop_nodes.shape[1] - ts = self.create_topological_sort(n) + ts = self.create_topological_sort_batch(n) pop_cal_seqs = ts(pop_nodes, pop_connections) forward_func = self.create_pop_batch_forward(n) @@ -352,9 +377,13 @@ class FunctionFactory: return debug_forward - # return partial( - # forward_func, - # cal_seqs=pop_cal_seqs, - # nodes=pop_nodes, - # connections=pop_connections - # ) + def ask_batch_forward(self, nodes, connections): + n = nodes.shape[0] + ts = self.create_topological_sort(n) + cal_seqs = ts(nodes, connections) + forward_func = self.create_batch_forward(n) + + def debug_forward(inputs): + return forward_func(inputs, cal_seqs, nodes, connections) + + return debug_forward diff --git a/algorithms/neat/genome/activations.py b/algorithms/neat/genome/activations.py index 48df72f..89e0f6a 100644 --- a/algorithms/neat/genome/activations.py +++ b/algorithms/neat/genome/activations.py @@ -68,6 +68,7 @@ def clamped_act(z): @jit def inv_act(z): + z = jnp.maximum(z, 1e-7) return 1 / z diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 6dcbaab..ac3a2bc 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -7,7 +7,6 @@ import numpy as np from .species import SpeciesController from .genome import expand, expand_single from .function_factory import FunctionFactory -from examples.time_utils import using_cprofile class Pipeline: @@ -16,7 +15,9 @@ class Pipeline: """ def __init__(self, config, seed=42): + self.time_dict = {} self.function_factory = FunctionFactory(config, debug=True) + self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) @@ -35,6 +36,7 @@ class Pipeline: self.species_controller.init_speciate(self.pop_nodes, self.pop_connections) self.best_fitness = float('-inf') + self.best_genome = None self.generation_timestamp = time.time() def ask(self): @@ -43,7 +45,7 @@ class Pipeline: :return: Algorithm gives the population a forward function, then environment gives back the fitnesses. """ - return self.function_factory.ask(self.pop_nodes, self.pop_connections) + return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections) def tell(self, fitnesses): @@ -72,10 +74,14 @@ class Pipeline: assert callable(analysis), f"What the fuck you passed in? A {analysis}?" analysis(fitnesses) + if max(fitnesses) >= self.config.neat.population.fitness_threshold: + print("Fitness limit reached!") + return self.best_genome + self.tell(fitnesses) print("Generation limit reached!") + return self.best_genome - # @using_cprofile def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None: """ create the next generation @@ -152,5 +158,10 @@ class Pipeline: cost_time = new_timestamp - self.generation_timestamp self.generation_timestamp = new_timestamp + max_idx = np.argmax(fitnesses) + if fitnesses[max_idx] > self.best_fitness: + self.best_fitness = fitnesses[max_idx] + self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx]) + print(f"Generation: {self.generation}", f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") diff --git a/examples/xor.py b/examples/xor.py index 2bd48ab..7bba04c 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -3,7 +3,7 @@ from functools import partial from utils import Configer from algorithms.neat import Pipeline from time_utils import using_cprofile -from problems import Sin, Xor +from problems import Sin, Xor, DIY # xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) @@ -25,11 +25,16 @@ from problems import Sin, Xor @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() - # problem = Xor() - problem = Sin() + config.neat.population.pop_size = 50 + problem = Xor() + # problem = Sin() + # problem = DIY(func=lambda x: (np.sin(x) + np.exp(x) - x ** 2) / (np.cos(x) + np.sqrt(x)) - np.log(x + 1)) problem.refactor_config(config) - pipeline = Pipeline(config, seed=11454) - pipeline.auto_run(problem.evaluate) + pipeline = Pipeline(config, seed=0) + best_nodes, best_connections = pipeline.auto_run(problem.evaluate) + # print(best_nodes, best_connections) + # func = pipeline.function_factory.ask_batch_forward(best_nodes, best_connections) + # problem.print(func) if __name__ == '__main__': diff --git a/problems/function_fitting/__init__.py b/problems/function_fitting/__init__.py index 4d597e0..8461221 100644 --- a/problems/function_fitting/__init__.py +++ b/problems/function_fitting/__init__.py @@ -1,3 +1,4 @@ from .function_fitting_problem import FunctionFittingProblem from .xor import * from .sin import * +from .diy import * \ No newline at end of file diff --git a/problems/function_fitting/diy.py b/problems/function_fitting/diy.py new file mode 100644 index 0000000..bc938d0 --- /dev/null +++ b/problems/function_fitting/diy.py @@ -0,0 +1,14 @@ +import numpy as np + +from . import FunctionFittingProblem + + +class DIY(FunctionFittingProblem): + def __init__(self, func, size=100): + self.num_inputs = 1 + self.num_outputs = 1 + self.batch = size + self.inputs = np.linspace(0, 1, self.batch)[:, None] + self.target = func(self.inputs) + print(self.inputs, self.target) + super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target) diff --git a/problems/function_fitting/function_fitting_problem.py b/problems/function_fitting/function_fitting_problem.py index c06132d..6a1bfd9 100644 --- a/problems/function_fitting/function_fitting_problem.py +++ b/problems/function_fitting/function_fitting_problem.py @@ -15,8 +15,25 @@ class FunctionFittingProblem(Problem): self.loss = loss super().__init__(self.forward_way, self.num_inputs, self.num_outputs, self.batch) - def evaluate(self, batch_forward_func): - out = batch_forward_func(self.inputs) - out = jax.device_get(out) - fitnesses = 1 - np.mean((self.target - out) ** 2, axis=(1, 2)) + def evaluate(self, pop_batch_forward): + outs = pop_batch_forward(self.inputs) + outs = jax.device_get(outs) + fitnesses = -np.mean((self.target - outs) ** 2, axis=(1, 2)) return fitnesses.tolist() + + def draw(self, batch_func): + outs = batch_func(self.inputs) + outs = jax.device_get(outs) + print(outs) + from matplotlib import pyplot as plt + plt.xlabel('x') + plt.ylabel('y') + plt.plot(self.inputs, self.target, color='red', label='target') + plt.plot(self.inputs, outs, color='blue', label='predict') + plt.legend() + plt.show() + + def print(self, batch_func): + outs = batch_func(self.inputs) + outs = jax.device_get(outs) + print(outs) \ No newline at end of file diff --git a/problems/function_fitting/sin.py b/problems/function_fitting/sin.py index 2b1ee03..f1ac005 100644 --- a/problems/function_fitting/sin.py +++ b/problems/function_fitting/sin.py @@ -8,7 +8,7 @@ class Sin(FunctionFittingProblem): self.num_inputs = 1 self.num_outputs = 1 self.batch = size - self.inputs = np.linspace(0, np.pi, self.batch)[:, None] + self.inputs = np.linspace(0, 2 * np.pi, self.batch)[:, None] self.target = np.sin(self.inputs) print(self.inputs, self.target) super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target) diff --git a/utils/default_config.json b/utils/default_config.json index 273e304..ef5065e 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -11,9 +11,9 @@ "neat": { "population": { "fitness_criterion": "max", - "fitness_threshold": 76, - "generation_limit": 100, - "pop_size": 1000, + "fitness_threshold": -0.001, + "generation_limit": 1000, + "pop_size": 30, "reset_on_extinction": "False" }, "gene": { @@ -33,12 +33,12 @@ }, "activation": { "default": "sigmoid", - "options": ["sigmoid"], + "options": "sigmoid", "mutate_rate": 0.1 }, "aggregation": { "default": "sum", - "options": ["sum"], + "options": "sum", "mutate_rate": 0.1 }, "weight": { @@ -57,12 +57,12 @@ "compatibility_weight_coefficient": 0.5, "single_structural_mutation": "False", "conn_add_prob": 0.5, - "conn_delete_prob": 0, - "node_add_prob": 0.1, - "node_delete_prob": 0 + "conn_delete_prob": 0.5, + "node_add_prob": 0.2, + "node_delete_prob": 0.2 }, "species": { - "compatibility_threshold": 3, + "compatibility_threshold": 2.5, "species_fitness_func": "max", "max_stagnation": 20, "species_elitism": 2,