From 0fdc856f2df071d2ca92a8a639bb9dc11ad2d9ba Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 9 May 2023 02:55:47 +0800 Subject: [PATCH] add function to put **all** compilation at the beginning of the execution. --- algorithms/neat/function_factory.py | 104 ++++++++++++++++++++++ algorithms/neat/genome/__init__.py | 3 +- algorithms/neat/genome/forward.py | 130 ++++++++++++++-------------- algorithms/neat/pipeline.py | 14 +-- examples/xor.py | 2 +- utils/default_config.json | 5 +- 6 files changed, 183 insertions(+), 75 deletions(-) diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index d29d0c0..27cc42c 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -7,6 +7,7 @@ import numpy as np from jax import jit, vmap from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover +from .genome import topological_sort, forward_single class FunctionFactory: @@ -24,6 +25,8 @@ class FunctionFactory: pass def load_config_vals(self, config): + self.problem_batch = config.basic.problem_batch + self.pop_size = config.neat.population.pop_size self.init_N = config.basic.init_maximum_nodes @@ -98,12 +101,17 @@ class FunctionFactory: self.create_mutate_with_args() self.create_distance_with_args() self.create_crossover_with_args() + self.create_topological_sort_with_args() + self.create_single_forward_with_args() + n = self.init_N print("start precompile") for _ in range(self.precompile_times): self.compile_mutate(n) self.compile_distance(n) self.compile_crossover(n) + self.compile_topological_sort(n) + self.compile_pop_batch_forward(n) n = int(self.expand_coe * n) print("end precompile") @@ -209,3 +217,99 @@ class FunctionFactory: if key not in self.compiled_function: self.compile_crossover(n) return self.compiled_function[key] + + def create_topological_sort_with_args(self): + self.topological_sort_with_args = topological_sort + + def compile_topological_sort(self, n): + func = self.topological_sort_with_args + func = vmap(func) + nodes_lower = np.zeros((self.pop_size, n, 5)) + connections_lower = np.zeros((self.pop_size, 2, n, n)) + func = jit(func).lower(nodes_lower, connections_lower).compile() + self.compiled_function[('topological_sort', n)] = func + + def create_topological_sort(self, n): + key = ('topological_sort', n) + if key not in self.compiled_function: + self.compile_topological_sort(n) + return self.compiled_function[key] + + def create_single_forward_with_args(self): + func = partial( + forward_single, + input_idx=self.input_idx, + output_idx=self.output_idx + ) + self.single_forward_with_args = func + + def compile_single_forward(self, n): + """ + single input for a genome + :param n: + :return: + """ + func = self.single_forward_with_args + inputs_lower = np.zeros((self.num_inputs,)) + cal_seqs_lower = np.zeros((n,), dtype=np.int32) + nodes_lower = np.zeros((n, 5)) + connections_lower = np.zeros((2, n, n)) + func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() + self.compiled_function[('single_forward', n)] = func + + def compile_pop_forward(self, n): + func = self.single_forward_with_args + func = vmap(func, in_axes=(None, 0, 0, 0)) + + inputs_lower = np.zeros((self.num_inputs,)) + cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32) + nodes_lower = np.zeros((self.pop_size, n, 5)) + connections_lower = np.zeros((self.pop_size, 2, n, n)) + func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() + self.compiled_function[('pop_forward', n)] = func + + def compile_batch_forward(self, n): + func = self.single_forward_with_args + func = vmap(func, in_axes=(0, None, None, None)) + + inputs_lower = np.zeros((self.problem_batch, self.num_inputs)) + cal_seqs_lower = np.zeros((n,), dtype=np.int32) + nodes_lower = np.zeros((n, 5)) + connections_lower = np.zeros((2, n, n)) + func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() + self.compiled_function[('batch_forward', n)] = func + + def compile_pop_batch_forward(self, n): + func = self.single_forward_with_args + func = vmap(func, in_axes=(0, None, None, None)) # batch_forward + func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward + + inputs_lower = np.zeros((self.problem_batch, self.num_inputs)) + cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32) + nodes_lower = np.zeros((self.pop_size, n, 5)) + connections_lower = np.zeros((self.pop_size, 2, n, n)) + + func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() + self.compiled_function[('pop_batch_forward', n)] = func + + def create_pop_batch_forward(self, n): + key = ('pop_batch_forward', n) + if key not in self.compiled_function: + self.compile_pop_batch_forward(n) + return self.compiled_function[key] + + def ask(self, pop_nodes, pop_connections): + n = pop_nodes.shape[1] + ts = self.create_topological_sort(n) + pop_cal_seqs = ts(pop_nodes, pop_connections) + + forward_func = self.create_pop_batch_forward(n) + + return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections) + + # return partial( + # forward_func, + # cal_seqs=pop_cal_seqs, + # nodes=pop_nodes, + # connections=pop_connections + # ) diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py index f276f13..1bc92c6 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -1,7 +1,8 @@ from .genome import expand, expand_single, pop_analysis, initialize_genomes -from .forward import create_forward_function +from .forward import create_forward_function, forward_single from .activations import act_name2key from .aggregations import agg_name2key from .crossover import crossover from .mutate import mutate from .distance import distance +from .graph import topological_sort diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index dc85fd4..9aa22c4 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -46,15 +46,14 @@ def create_forward_function(nodes: NDArray, connections: NDArray, raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}") -@partial(jit, static_argnames=['N']) -def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, - cal_seqs: Array, nodes: Array, connections: Array) -> Array: +@jit +def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array, + input_idx: Array, output_idx: Array) -> Array: """ jax forward for single input shaped (input_num, ) nodes, connections are single genome :argument inputs: (input_num, ) - :argument N: int :argument input_idx: (input_num, ) :argument output_idx: (output_num, ) :argument cal_seqs: (N, ) @@ -63,6 +62,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, :return (output_num, ) """ + N = nodes.shape[0] ini_vals = jnp.full((N,), jnp.nan) ini_vals = ini_vals.at[input_idx].set(inputs) @@ -86,64 +86,64 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, return vals[output_idx] -@partial(jit, static_argnames=['N']) -@partial(vmap, in_axes=(0, None, None, None, None, None, None)) -def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, - cal_seqs: Array, nodes: Array, connections: Array) -> Array: - """ - jax forward for batch_inputs shaped (batch_size, input_num) - nodes, connections are single genome - - :argument batch_inputs: (batch_size, input_num) - :argument N: int - :argument input_idx: (input_num, ) - :argument output_idx: (output_num, ) - :argument cal_seqs: (N, ) - :argument nodes: (N, 5) - :argument connections: (2, N, N) - - :return (batch_size, output_num) - """ - return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) - - -@partial(jit, static_argnames=['N']) -@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) -def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, - pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: - """ - jax forward for single input shaped (input_num, ) - pop_nodes, pop_connections are population of genomes - - :argument inputs: (input_num, ) - :argument N: int - :argument input_idx: (input_num, ) - :argument output_idx: (output_num, ) - :argument pop_cal_seqs: (pop_size, N) - :argument pop_nodes: (pop_size, N, 5) - :argument pop_connections: (pop_size, 2, N, N) - - :return (pop_size, output_num) - """ - return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) - - -@partial(jit, static_argnames=['N']) -@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) -def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, - pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: - """ - jax forward for batch input shaped (batch, input_num) - pop_nodes, pop_connections are population of genomes - - :argument batch_inputs: (batch_size, input_num) - :argument N: int - :argument input_idx: (input_num, ) - :argument output_idx: (output_num, ) - :argument pop_cal_seqs: (pop_size, N) - :argument pop_nodes: (pop_size, N, 5) - :argument pop_connections: (pop_size, 2, N, N) - - :return (pop_size, batch_size, output_num) - """ - return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) +# @partial(jit, static_argnames=['N']) +# @partial(vmap, in_axes=(0, None, None, None, None, None, None)) +# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, +# cal_seqs: Array, nodes: Array, connections: Array) -> Array: +# """ +# jax forward for batch_inputs shaped (batch_size, input_num) +# nodes, connections are single genome +# +# :argument batch_inputs: (batch_size, input_num) +# :argument N: int +# :argument input_idx: (input_num, ) +# :argument output_idx: (output_num, ) +# :argument cal_seqs: (N, ) +# :argument nodes: (N, 5) +# :argument connections: (2, N, N) +# +# :return (batch_size, output_num) +# """ +# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) +# +# +# @partial(jit, static_argnames=['N']) +# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) +# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, +# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: +# """ +# jax forward for single input shaped (input_num, ) +# pop_nodes, pop_connections are population of genomes +# +# :argument inputs: (input_num, ) +# :argument N: int +# :argument input_idx: (input_num, ) +# :argument output_idx: (output_num, ) +# :argument pop_cal_seqs: (pop_size, N) +# :argument pop_nodes: (pop_size, N, 5) +# :argument pop_connections: (pop_size, 2, N, N) +# +# :return (pop_size, output_num) +# """ +# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) +# +# +# @partial(jit, static_argnames=['N']) +# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) +# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, +# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: +# """ +# jax forward for batch input shaped (batch, input_num) +# pop_nodes, pop_connections are population of genomes +# +# :argument batch_inputs: (batch_size, input_num) +# :argument N: int +# :argument input_idx: (input_num, ) +# :argument output_idx: (output_num, ) +# :argument pop_cal_seqs: (pop_size, N) +# :argument pop_nodes: (pop_size, N, 5) +# :argument pop_connections: (pop_size, 2, N, N) +# +# :return (pop_size, batch_size, output_num) +# """ +# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index a1160c6..ae4d097 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -37,16 +37,18 @@ class Pipeline: self.best_fitness = float('-inf') self.generation_timestamp = time.time() - def ask(self, batch: bool): + def ask(self): """ Create a forward function for the population. - :param batch: :return: Algorithm gives the population a forward function, then environment gives back the fitnesses. """ - func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx, - batch=batch) - return func + return self.function_factory.ask(self.pop_nodes, self.pop_connections) + + # + # func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx, + # batch=batch) + # return func def tell(self, fitnesses): @@ -65,7 +67,7 @@ class Pipeline: def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config.neat.population.generation_limit): - forward_func = self.ask(batch=True) + forward_func = self.ask() fitnesses = fitness_func(forward_func) if analysis is not None: diff --git a/examples/xor.py b/examples/xor.py index 209f169..13c40cc 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -8,7 +8,7 @@ from utils import Configer from algorithms.neat import Pipeline from time_utils import using_cprofile -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]]) diff --git a/utils/default_config.json b/utils/default_config.json index 0bcdb21..b10133d 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -2,9 +2,10 @@ "basic": { "num_inputs": 2, "num_outputs": 1, - "init_maximum_nodes": 10, + "problem_batch": 4, + "init_maximum_nodes": 20, "expands_coe": 2, - "pre_compile_times": 3 + "pre_compile_times": 0 }, "neat": { "population": {