From 0cb2f9473d33659047ef831675e5c92ecbd73adb Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 25 Jun 2023 00:26:52 +0800 Subject: [PATCH] finish ask part of the algorithm; use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward --- configs/activations.py | 32 --- configs/aggregations.py | 20 -- configs/configer.py | 34 ++- configs/default_config.ini | 7 +- examples/a.py | 9 +- examples/jax_playground.py | 15 +- examples/time_utils.py | 28 -- examples/xor.ini | 3 + neat/function_factory.py | 381 ++++++-------------------- neat/genome/activations.py | 28 +- neat/genome/aggregations.py | 53 +--- neat/genome/crossover.py | 17 +- neat/genome/crossover_.py | 81 ------ neat/genome/distance.py | 82 +++--- neat/genome/distance_.py | 119 -------- neat/genome/forward.py | 99 ++++--- neat/genome/{genome_.py => genome.py} | 14 +- neat/genome/graph.py | 76 +++-- neat/genome/mutate.py | 291 +++++++------------- neat/genome/mutate_.py | 355 ------------------------ neat/genome/utils.py | 53 +--- neat/genome/utils_.py | 102 ------- neat/pipeline.py | 182 ++++-------- neat/pipeline_.py | 27 -- 24 files changed, 485 insertions(+), 1623 deletions(-) delete mode 100644 configs/activations.py delete mode 100644 configs/aggregations.py delete mode 100644 examples/time_utils.py delete mode 100644 neat/genome/crossover_.py delete mode 100644 neat/genome/distance_.py rename neat/genome/{genome_.py => genome.py} (86%) delete mode 100644 neat/genome/mutate_.py delete mode 100644 neat/genome/utils_.py delete mode 100644 neat/pipeline_.py diff --git a/configs/activations.py b/configs/activations.py deleted file mode 100644 index 677105f..0000000 --- a/configs/activations.py +++ /dev/null @@ -1,32 +0,0 @@ -from neat.genome.activations import * - -ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act, - identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act] - -act_name2key = { - 'sigmoid': 0, - 'tanh': 1, - 'sin': 2, - 'gauss': 3, - 'relu': 4, - 'elu': 5, - 'lelu': 6, - 'selu': 7, - 'softplus': 8, - 'identity': 9, - 'clamped': 10, - 'inv': 11, - 'log': 12, - 'exp': 13, - 'abs': 14, - 'hat': 15, - 'square': 16, - 'cube': 17, -} - - -def refactor_act(config): - config['activation_default'] = act_name2key[config['activation_default']] - config['activation_options'] = [ - act_name2key[act_name] for act_name in config['activation_options'] - ] diff --git a/configs/aggregations.py b/configs/aggregations.py deleted file mode 100644 index 439db7c..0000000 --- a/configs/aggregations.py +++ /dev/null @@ -1,20 +0,0 @@ -from neat.genome.aggregations import * - -AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg] - -agg_name2key = { - 'sum': 0, - 'product': 1, - 'max': 2, - 'min': 3, - 'maxabs': 4, - 'median': 5, - 'mean': 6, -} - - -def refactor_agg(config): - config['aggregation_default'] = agg_name2key[config['aggregation_default']] - config['aggregation_options'] = [ - agg_name2key[act_name] for act_name in config['aggregation_options'] - ] diff --git a/configs/configer.py b/configs/configer.py index b433256..9118a38 100644 --- a/configs/configer.py +++ b/configs/configer.py @@ -4,8 +4,8 @@ import configparser import numpy as np -from .activations import refactor_act -from .aggregations import refactor_agg +from neat.genome.activations import act_name2func +from neat.genome.aggregations import agg_name2func # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. jit_config_keys = [ @@ -20,12 +20,12 @@ jit_config_keys = [ "node_delete_prob", "compatibility_threshold", "bias_init_mean", - "bias_init_stdev", + "bias_init_std", "bias_mutate_power", "bias_mutate_rate", "bias_replace_rate", "response_init_mean", - "response_init_stdev", + "response_init_std", "response_mutate_power", "response_mutate_rate", "response_replace_rate", @@ -36,7 +36,7 @@ jit_config_keys = [ "aggregation_options", "aggregation_replace_rate", "weight_init_mean", - "weight_init_stdev", + "weight_init_std", "weight_mutate_power", "weight_mutate_rate", "weight_replace_rate", @@ -90,14 +90,26 @@ class Configer: cls.__check_redundant_config(default_config, config) cls.__complete_config(default_config, config) - refactor_act(config) - refactor_agg(config) - input_idx = np.arange(config['num_inputs']) - output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) - config['input_idx'] = input_idx - config['output_idx'] = output_idx + cls.refactor_activation(config) + cls.refactor_aggregation(config) + + config['input_idx'] = np.arange(config['num_inputs']) + config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) + return config + @classmethod + def refactor_activation(cls, config): + config['activation_default'] = 0 + config['activation_options'] = np.arange(len(config['activation_option_names'])) + config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']] + + @classmethod + def refactor_aggregation(cls, config): + config['aggregation_default'] = 0 + config['aggregation_options'] = np.arange(len(config['aggregation_option_names'])) + config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']] + @classmethod def create_jit_config(cls, config): jit_config = {k: config[k] for k in jit_config_keys} diff --git a/configs/default_config.ini b/configs/default_config.ini index f9cc260..b45b49f 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -5,7 +5,8 @@ init_maximum_nodes = 20 init_maximum_connections = 20 init_maximum_species = 10 expands_coe = 2.0 -forward_way = "pop_batch" +forward_way = "pop" +batch_size = 4 [population] fitness_threshold = 100000 @@ -46,12 +47,12 @@ response_replace_rate = 0.0 [gene-activation] activation_default = "sigmoid" -activation_options = ["sigmoid"] +activation_option_names = ["sigmoid"] activation_replace_rate = 0.0 [gene-aggregation] aggregation_default = "sum" -aggregation_options = ["sum"] +aggregation_option_names = ["sum"] aggregation_replace_rate = 0.0 [gene-weight] diff --git a/examples/a.py b/examples/a.py index 5e874ce..c7138d9 100644 --- a/examples/a.py +++ b/examples/a.py @@ -3,6 +3,9 @@ import numpy as np import jax.numpy as jnp import jax +a = {1:2, 2:3, 4:5} +print(a.values()) + a = jnp.array([1, 0, 1, 0, np.nan]) b = jnp.array([1, 1, 1, 1, 1]) c = jnp.array([1, 1, 1, 1, 1]) @@ -44,5 +47,9 @@ def func(x): else: return 2 +a = jnp.zeros((3, 3)) +print(a.dtype) -print(main()) \ No newline at end of file +c = None +b = 1 or c +print(b) \ No newline at end of file diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 55357ad..379c28a 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,16 +1,25 @@ from functools import partial +import numpy as np import jax from jax import jit from configs import Configer -from neat.pipeline_ import Pipeline +from neat.pipeline import Pipeline +from neat.function_factory import FunctionFactory +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) def main(): config = Configer.load_config("xor.ini") - print(config) - pipeline = Pipeline(config) + function_factory = FunctionFactory(config) + pipeline = Pipeline(config, function_factory) + forward_func = pipeline.ask() + # inputs = np.tile(xor_inputs, (150, 1, 1)) + outputs = forward_func(xor_inputs) + print(outputs) + @jit diff --git a/examples/time_utils.py b/examples/time_utils.py deleted file mode 100644 index 2fd6e6f..0000000 --- a/examples/time_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import cProfile -from io import StringIO -import pstats - - -def using_cprofile(func, root_abs_path=None, replace_pattern=None, save_path=None): - def inner(*args, **kwargs): - pr = cProfile.Profile() - pr.enable() - ret = func(*args, **kwargs) - pr.disable() - profile_stats = StringIO() - stats = pstats.Stats(pr, stream=profile_stats) - if root_abs_path is not None: - stats.sort_stats('cumulative').print_stats(root_abs_path) - else: - stats.sort_stats('cumulative').print_stats() - output = profile_stats.getvalue() - if replace_pattern is not None: - output = output.replace(replace_pattern, "") - if save_path is None: - print(output) - else: - with open(save_path, "w") as f: - f.write(output) - return ret - - return inner diff --git a/examples/xor.ini b/examples/xor.ini index b57272e..79a1110 100644 --- a/examples/xor.ini +++ b/examples/xor.ini @@ -1,2 +1,5 @@ +[basic] +forward_way = "common" + [population] fitness_threshold = -1e-2 \ No newline at end of file diff --git a/neat/function_factory.py b/neat/function_factory.py index 8c05bb3..effed6a 100644 --- a/neat/function_factory.py +++ b/neat/function_factory.py @@ -1,323 +1,108 @@ -""" -Lowers, compiles, and creates functions used in the NEAT pipeline. -""" -from functools import partial -import time - import numpy as np from jax import jit, vmap -from .genome import act_name2key, agg_name2key, initialize_genomes -from .genome import topological_sort, forward_single, unflatten_connections -from .population import create_next_generation_then_speciate +from .genome.forward import create_forward +from .genome.utils import unflatten_connections +from .genome.graph import topological_sort + + +def hash_symbols(symbols): + return symbols['P'], symbols['N'], symbols['C'], symbols['S'] class FunctionFactory: + """ + Creates and compiles functions used in the NEAT pipeline. + """ + def __init__(self, config): self.config = config + self.func_dict = {} + self.function_info = {} - self.expand_coe = config.basic.expands_coe - self.precompile_times = config.basic.pre_compile_times - self.compiled_function = {} - self.compile_time = 0 + # (inputs_nums, ) -> (outputs_nums, ) + forward = create_forward(config) # input size (inputs_nums, ) - self.load_config_vals(config) + # (batch_size, inputs_nums) -> (batch_size, outputs_nums) + batch_forward = vmap(forward, in_axes=(0, None, None, None)) - self.create_topological_sort_with_args() - self.create_single_forward_with_args() - self.create_update_speciate_with_args() + # (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) + pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0)) - def load_config_vals(self, config): - self.compatibility_threshold = self.config.neat.species.compatibility_threshold + # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) + common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0)) - self.problem_batch = config.basic.problem_batch - self.pop_size = config.neat.population.pop_size + self.function_info = { + "pop_unflatten_connections": { + 'func': vmap(unflatten_connections), + 'lowers': [ + {'shape': ('P', 'N', 5), 'type': np.float32}, + {'shape': ('P', 'C', 4), 'type': np.float32} + ] + }, - self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient - self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient + "pop_topological_sort": { + 'func': vmap(topological_sort), + 'lowers': [ + {'shape': ('P', 'N', 5), 'type': np.float32}, + {'shape': ('P', 2, 'N', 'N'), 'type': np.float32}, + ] + }, - self.num_inputs = config.basic.num_inputs - self.num_outputs = config.basic.num_outputs - self.input_idx = np.arange(self.num_inputs) - self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs) + "batch_forward": { + 'func': batch_forward, + 'lowers': [ + {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32}, + {'shape': ('N', ), 'type': np.int32}, + {'shape': ('N', 5), 'type': np.float32}, + {'shape': (2, 'N', 'N'), 'type': np.float32} + ] + }, - bias = config.neat.gene.bias - self.bias_mean = bias.init_mean - self.bias_std = bias.init_stdev - self.bias_mutate_strength = bias.mutate_power - self.bias_mutate_rate = bias.mutate_rate - self.bias_replace_rate = bias.replace_rate + "pop_batch_forward": { + 'func': pop_batch_forward, + 'lowers': [ + {'shape': ('P', config['batch_size'], config['num_inputs']), 'type': np.float32}, + {'shape': ('P', 'N'), 'type': np.int32}, + {'shape': ('P', 'N', 5), 'type': np.float32}, + {'shape': ('P', 2, 'N', 'N'), 'type': np.float32} + ] + }, - response = config.neat.gene.response - self.response_mean = response.init_mean - self.response_std = response.init_stdev - self.response_mutate_strength = response.mutate_power - self.response_mutate_rate = response.mutate_rate - self.response_replace_rate = response.replace_rate - - weight = config.neat.gene.weight - self.weight_mean = weight.init_mean - self.weight_std = weight.init_stdev - self.weight_mutate_strength = weight.mutate_power - self.weight_mutate_rate = weight.mutate_rate - self.weight_replace_rate = weight.replace_rate - - activation = config.neat.gene.activation - self.act_default = act_name2key[activation.default] - self.act_list = np.array([act_name2key[name] for name in activation.options]) - self.act_replace_rate = activation.mutate_rate - - aggregation = config.neat.gene.aggregation - self.agg_default = agg_name2key[aggregation.default] - self.agg_list = np.array([agg_name2key[name] for name in aggregation.options]) - self.agg_replace_rate = aggregation.mutate_rate - - enabled = config.neat.gene.enabled - self.enabled_reverse_rate = enabled.mutate_rate - - genome = config.neat.genome - self.add_node_rate = genome.node_add_prob - self.delete_node_rate = genome.node_delete_prob - self.add_connection_rate = genome.conn_add_prob - self.delete_connection_rate = genome.conn_delete_prob - self.single_structure_mutate = genome.single_structural_mutation - - def create_initialize(self, N, C): - func = partial( - initialize_genomes, - pop_size=self.pop_size, - N=N, - C=C, - num_inputs=self.num_inputs, - num_outputs=self.num_outputs, - default_bias=self.bias_mean, - default_response=self.response_mean, - default_act=self.act_default, - default_agg=self.agg_default, - default_weight=self.weight_mean - ) - return func - - def create_update_speciate_with_args(self): - species_kwargs = { - "disjoint_coe": self.disjoint_coe, - "compatibility_coe": self.compatibility_coe, - "compatibility_threshold": self.compatibility_threshold + 'common_forward': { + 'func': common_forward, + 'lowers': [ + {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32}, + {'shape': ('P', 'N'), 'type': np.int32}, + {'shape': ('P', 'N', 5), 'type': np.float32}, + {'shape': ('P', 2, 'N', 'N'), 'type': np.float32} + ] + } } - mutate_kwargs = { - "input_idx": self.input_idx, - "output_idx": self.output_idx, - "bias_mean": self.bias_mean, - "bias_std": self.bias_std, - "bias_mutate_strength": self.bias_mutate_strength, - "bias_mutate_rate": self.bias_mutate_rate, - "bias_replace_rate": self.bias_replace_rate, - "response_mean": self.response_mean, - "response_std": self.response_std, - "response_mutate_strength": self.response_mutate_strength, - "response_mutate_rate": self.response_mutate_rate, - "response_replace_rate": self.response_replace_rate, - "weight_mean": self.weight_mean, - "weight_std": self.weight_std, - "weight_mutate_strength": self.weight_mutate_strength, - "weight_mutate_rate": self.weight_mutate_rate, - "weight_replace_rate": self.weight_replace_rate, - "act_default": self.act_default, - "act_list": self.act_list, - "act_replace_rate": self.act_replace_rate, - "agg_default": self.agg_default, - "agg_list": self.agg_list, - "agg_replace_rate": self.agg_replace_rate, - "enabled_reverse_rate": self.enabled_reverse_rate, - "add_node_rate": self.add_node_rate, - "delete_node_rate": self.delete_node_rate, - "add_connection_rate": self.add_connection_rate, - "delete_connection_rate": self.delete_connection_rate, - } - self.update_speciate_with_args = partial( - create_next_generation_then_speciate, - species_kwargs=species_kwargs, - mutate_kwargs=mutate_kwargs - ) + def get(self, name, symbols): + if (name, hash_symbols(symbols)) not in self.func_dict: + self.compile(name, symbols) + return self.func_dict[name, hash_symbols(symbols)] - def create_update_speciate(self, N, C, S): - key = ("update_speciate", N, C, S) - if key not in self.compiled_function: - self.compile_update_speciate(N, C, S) - return self.compiled_function[key] + def compile(self, name, symbols): + # prepare function prototype + func = self.function_info[name]['func'] - def compile_update_speciate(self, N, C, S): - s = time.time() + # prepare lower operands + lowers_operands = [] + for lower in self.function_info[name]['lowers']: + shape = list(lower['shape']) + for i, s in enumerate(shape): + if s in symbols: + shape[i] = symbols[s] + assert isinstance(shape[i], int) + lowers_operands.append(np.zeros(shape, dtype=lower['type'])) - func = self.update_speciate_with_args - randkey_lower = np.zeros((2,), dtype=np.uint32) - pop_nodes_lower = np.zeros((self.pop_size, N, 5)) - pop_cons_lower = np.zeros((self.pop_size, C, 4)) - winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32) - loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32) - elite_mask_lower = np.zeros((self.pop_size,), dtype=bool) - new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32) - pre_spe_center_nodes_lower = np.zeros((S, N, 5)) - pre_spe_center_cons_lower = np.zeros((S, C, 4)) - species_keys = np.zeros((S,), dtype=np.int32) - new_species_keys_lower = 0 - compiled_func = jit(func).lower( - randkey_lower, - pop_nodes_lower, - pop_cons_lower, - winner_part_lower, - loser_part_lower, - elite_mask_lower, - new_node_keys_start_lower, - pre_spe_center_nodes_lower, - pre_spe_center_cons_lower, - species_keys, - new_species_keys_lower, - ).compile() - self.compiled_function[("update_speciate", N, C, S)] = compiled_func + # compile + compiled_func = jit(func).lower(*lowers_operands).compile() - self.compile_time += time.time() - s - - def create_topological_sort_with_args(self): - self.topological_sort_with_args = topological_sort - - def compile_topological_sort(self, n): - s = time.time() - - func = self.topological_sort_with_args - nodes_lower = np.zeros((n, 5)) - connections_lower = np.zeros((2, n, n)) - func = jit(func).lower(nodes_lower, connections_lower).compile() - self.compiled_function[('topological_sort', n)] = func - - self.compile_time += time.time() - s - - def create_topological_sort(self, n): - key = ('topological_sort', n) - if key not in self.compiled_function: - self.compile_topological_sort(n) - return self.compiled_function[key] - - def compile_topological_sort_batch(self, n): - s = time.time() - - func = self.topological_sort_with_args - func = vmap(func) - nodes_lower = np.zeros((self.pop_size, n, 5)) - connections_lower = np.zeros((self.pop_size, 2, n, n)) - func = jit(func).lower(nodes_lower, connections_lower).compile() - self.compiled_function[('topological_sort_batch', n)] = func - - self.compile_time += time.time() - s - - def create_topological_sort_batch(self, n): - key = ('topological_sort_batch', n) - if key not in self.compiled_function: - self.compile_topological_sort_batch(n) - return self.compiled_function[key] - - def create_single_forward_with_args(self): - func = partial( - forward_single, - input_idx=self.input_idx, - output_idx=self.output_idx - ) - self.single_forward_with_args = func - - - def compile_batch_forward(self, n): - s = time.time() - - func = self.single_forward_with_args - func = vmap(func, in_axes=(0, None, None, None)) - - inputs_lower = np.zeros((self.problem_batch, self.num_inputs)) - cal_seqs_lower = np.zeros((n,), dtype=np.int32) - nodes_lower = np.zeros((n, 5)) - connections_lower = np.zeros((2, n, n)) - func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() - self.compiled_function[('batch_forward', n)] = func - - self.compile_time += time.time() - s - - def create_batch_forward(self, n): - key = ('batch_forward', n) - if key not in self.compiled_function: - self.compile_batch_forward(n) - - return self.compiled_function[key] - - def compile_pop_batch_forward(self, n): - - s = time.time() - - func = self.single_forward_with_args - func = vmap(func, in_axes=(0, None, None, None)) # batch_forward - func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward - - inputs_lower = np.zeros((self.problem_batch, self.num_inputs)) - cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32) - nodes_lower = np.zeros((self.pop_size, n, 5)) - connections_lower = np.zeros((self.pop_size, 2, n, n)) - - func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile() - self.compiled_function[('pop_batch_forward', n)] = func - - self.compile_time += time.time() - s - - def create_pop_batch_forward(self, n): - key = ('pop_batch_forward', n) - if key not in self.compiled_function: - self.compile_pop_batch_forward(n) - - return self.compiled_function[key] - - def ask_pop_batch_forward(self, pop_nodes, pop_cons): - n, c = pop_nodes.shape[1], pop_cons.shape[1] - batch_unflatten_func = self.create_batch_unflatten_connections(n, c) - pop_cons = batch_unflatten_func(pop_nodes, pop_cons) - ts = self.create_topological_sort_batch(n) - - # for connections with enabled is false, set weight to 0) - pop_cal_seqs = ts(pop_nodes, pop_cons) - # print(pop_cal_seqs) - forward_func = self.create_pop_batch_forward(n) - - def debug_forward(inputs): - return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons) - - return debug_forward - - def ask_batch_forward(self, nodes, connections): - n = nodes.shape[0] - ts = self.create_topological_sort(n) - cal_seqs = ts(nodes, connections) - forward_func = self.create_batch_forward(n) - - def debug_forward(inputs): - return forward_func(inputs, cal_seqs, nodes, connections) - - return debug_forward - - def compile_batch_unflatten_connections(self, n, c): - - s = time.time() - - func = unflatten_connections - func = vmap(func) - pop_nodes_lower = np.zeros((self.pop_size, n, 5)) - pop_connections_lower = np.zeros((self.pop_size, c, 4)) - func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile() - self.compiled_function[('batch_unflatten_connections', n, c)] = func - - self.compile_time += time.time() - s - - def create_batch_unflatten_connections(self, n, c): - key = ('batch_unflatten_connections', n, c) - if key not in self.compiled_function: - self.compile_batch_unflatten_connections(n, c) - - return self.compiled_function[key] + # save for reuse + self.func_dict[name, hash_symbols(symbols)] = compiled_func diff --git a/neat/genome/activations.py b/neat/genome/activations.py index 247be85..48bff71 100644 --- a/neat/genome/activations.py +++ b/neat/genome/activations.py @@ -104,11 +104,23 @@ def cube_act(z): return z ** 3 -@jit -def act(idx, z): - idx = jnp.asarray(idx, dtype=jnp.int32) - # change idx from float to int - res = jax.lax.switch(idx, ACT_TOTAL_LIST, z) - return jnp.where(jnp.isnan(res), jnp.nan, res) - - # return jax.lax.switch(idx, ACT_TOTAL_LIST, z) +act_name2func = { + 'sigmoid': sigmoid_act, + 'tanh': tanh_act, + 'sin': sin_act, + 'gauss': gauss_act, + 'relu': relu_act, + 'elu': elu_act, + 'lelu': lelu_act, + 'selu': selu_act, + 'softplus': softplus_act, + 'identity': identity_act, + 'clamped': clamped_act, + 'inv': inv_act, + 'log': log_act, + 'exp': exp_act, + 'abs': abs_act, + 'hat': hat_act, + 'square': square_act, + 'cube': cube_act, +} diff --git a/neat/genome/aggregations.py b/neat/genome/aggregations.py index 119f175..ed221f1 100644 --- a/neat/genome/aggregations.py +++ b/neat/genome/aggregations.py @@ -1,9 +1,3 @@ -""" -aggregations, two special case need to consider: -1. extra 0s -2. full of 0s -""" - import jax import jax.numpy as jnp import numpy as np @@ -44,19 +38,13 @@ def maxabs_agg(z): @jit def median_agg(z): - non_zero_mask = ~jnp.isnan(z) - n = jnp.sum(non_zero_mask, axis=0) + non_nan_mask = ~jnp.isnan(z) + n = jnp.sum(non_nan_mask, axis=0) - z = jnp.where(jnp.isnan(z), jnp.inf, z) - sorted_valid_values = jnp.sort(z) + z = jnp.sort(z) # sort - def _even_case(): - return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2 - - def _odd_case(): - return sorted_valid_values[n // 2] - - median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case) + idx1, idx2 = (n - 1) // 2, n // 2 + median = (z[idx1] + z[idx2]) / 2 return median @@ -70,25 +58,12 @@ def mean_agg(z): return mean_without_zeros -@jit -def agg(idx, z): - idx = jnp.asarray(idx, dtype=jnp.int32) - - def full_nan(): - return 0. - - def not_full_nan(): - return jax.lax.switch(idx, AGG_TOTAL_LIST, z) - - return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan) - - - -if __name__ == '__main__': - array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32) - for names in agg_name2key.keys(): - print(names, agg(agg_name2key[names], array)) - - array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32) - for names in agg_name2key.keys(): - print(names, agg(agg_name2key[names], array2)) +agg_name2func = { + 'sum': sum_agg, + 'product': product_agg, + 'max': max_agg, + 'min': min_agg, + 'maxabs': maxabs_agg, + 'median': median_agg, + 'mean': mean_agg, +} diff --git a/neat/genome/crossover.py b/neat/genome/crossover.py index 0873b98..2a02d9b 100644 --- a/neat/genome/crossover.py +++ b/neat/genome/crossover.py @@ -1,14 +1,17 @@ -from functools import partial +""" +Crossover two genomes to generate a new genome. +The calculation method is the same as the crossover operation in NEAT-python. +See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover +""" from typing import Tuple import jax -from jax import jit, vmap, Array +from jax import jit, Array from jax import numpy as jnp @jit -def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \ - -> Tuple[Array, Array]: +def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]: """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) @@ -23,7 +26,11 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + # make homologous genes align in nodes2 align with nodes1 nodes2 = align_array(keys1, keys2, nodes2, 'node') + + # For not homologous genes, use the value of nodes1(winner) + # For homologous genes, use the crossover result between nodes1 and nodes2 new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) # crossover connections @@ -34,7 +41,6 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: return new_nodes, new_cons -# @partial(jit, static_argnames=['gene_type']) def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: """ After I review this code, I found that it is the most difficult part of the code. Please never change it! @@ -62,7 +68,6 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: return refactor_ar2 -# @jit def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: """ crossover two genes diff --git a/neat/genome/crossover_.py b/neat/genome/crossover_.py deleted file mode 100644 index 2a02d9b..0000000 --- a/neat/genome/crossover_.py +++ /dev/null @@ -1,81 +0,0 @@ -""" -Crossover two genomes to generate a new genome. -The calculation method is the same as the crossover operation in NEAT-python. -See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover -""" -from typing import Tuple - -import jax -from jax import jit, Array -from jax import numpy as jnp - - -@jit -def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]: - """ - use genome1 and genome2 to generate a new genome - notice that genome1 should have higher fitness than genome2 (genome1 is winner!) - :param randkey: - :param nodes1: - :param cons1: - :param nodes2: - :param cons2: - :return: - """ - randkey_1, randkey_2 = jax.random.split(randkey) - - # crossover nodes - keys1, keys2 = nodes1[:, 0], nodes2[:, 0] - # make homologous genes align in nodes2 align with nodes1 - nodes2 = align_array(keys1, keys2, nodes2, 'node') - - # For not homologous genes, use the value of nodes1(winner) - # For homologous genes, use the crossover result between nodes1 and nodes2 - new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) - - # crossover connections - con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] - cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') - new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) - - return new_nodes, new_cons - - -def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: - """ - After I review this code, I found that it is the most difficult part of the code. Please never change it! - make ar2 align with ar1. - :param seq1: - :param seq2: - :param ar2: - :param gene_type: - :return: - align means to intersect part of ar2 will be at the same position as ar1, - non-intersect part of ar2 will be set to Nan - """ - seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :] - mask = (seq1 == seq2) & (~jnp.isnan(seq1)) - - if gene_type == 'connection': - mask = jnp.all(mask, axis=2) - - intersect_mask = mask.any(axis=1) - idx = jnp.arange(0, len(seq1)) - idx_fixed = jnp.dot(mask, idx) - - refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan) - - return refactor_ar2 - - -def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: - """ - crossover two genes - :param rand_key: - :param g1: - :param g2: - :return: - only gene with the same key will be crossover, thus don't need to consider change key - """ - r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) diff --git a/neat/genome/distance.py b/neat/genome/distance.py index 2a6519a..69e421e 100644 --- a/neat/genome/distance.py +++ b/neat/genome/distance.py @@ -1,6 +1,9 @@ """ Calculate the distance between two genomes. +The calculation method is the same as the distance calculation in NEAT-python. +See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py """ +from typing import Dict from jax import jit, vmap, Array from jax import numpy as jnp @@ -9,26 +12,34 @@ from .utils import EMPTY_NODE, EMPTY_CON @jit -def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1., - compatibility_coe: float = 0.5) -> Array: +def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array: """ Calculate the distance between two genomes. - nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg] - connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] + args: + nodes1: Array(N, 5) + cons1: Array(C, 4) + nodes2: Array(N, 5) + cons2: Array(C, 4) + returns: + distance: Array(, ) """ - - nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance - - cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance + nd = node_distance(nodes1, nodes2, jit_config) # node distance + cd = connection_distance(cons1, cons2, jit_config) # connection distance return nd + cd @jit -def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5): +def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict): + """ + Calculate the distance between nodes of two genomes. + """ + # statistics nodes count of two genomes node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) max_cnt = jnp.maximum(node_cnt1, node_cnt2) + # align homologous nodes + # this process is similar to np.intersect1d. nodes = jnp.concatenate((nodes1, nodes2), axis=0) keys = nodes[:, 0] sorted_indices = jnp.argsort(keys, axis=0) @@ -36,19 +47,29 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5): nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end fr, sr = nodes[:-1], nodes[1:] # first row, second row + # flag location of homologous nodes intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) + # calculate the count of non_homologous of two genomes non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) - nd = batch_homologous_node_distance(fr, sr) - nd = jnp.where(jnp.isnan(nd), 0, nd) - homologous_distance = jnp.sum(nd * intersect_mask) - val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe - return jnp.where(max_cnt == 0, 0, val / max_cnt) + # calculate the distance of homologous nodes + hnd = vmap(homologous_node_distance)(fr, sr) + hnd = jnp.where(jnp.isnan(hnd), 0, hnd) + homologous_distance = jnp.sum(hnd * intersect_mask) + + val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ + 'compatibility_weight'] + + return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division @jit -def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): +def connection_distance(cons1: Array, cons2: Array, jit_config: Dict): + """ + Calculate the distance between connections of two genomes. + Similar process as node_distance. + """ con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) max_cnt = jnp.maximum(con_cnt1, con_cnt2) @@ -64,37 +85,34 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - cd = batch_homologous_connection_distance(fr, sr) - cd = jnp.where(jnp.isnan(cd), 0, cd) - homologous_distance = jnp.sum(cd * intersect_mask) + hcd = vmap(homologous_connection_distance)(fr, sr) + hcd = jnp.where(jnp.isnan(hcd), 0, hcd) + homologous_distance = jnp.sum(hcd * intersect_mask) - val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe + val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ + 'compatibility_weight'] return jnp.where(max_cnt == 0, 0, val / max_cnt) -@vmap -def batch_homologous_node_distance(b_n1, b_n2): - return homologous_node_distance(b_n1, b_n2) - - -@vmap -def batch_homologous_connection_distance(b_c1, b_c2): - return homologous_connection_distance(b_c1, b_c2) - - @jit -def homologous_node_distance(n1, n2): +def homologous_node_distance(n1: Array, n2: Array): + """ + Calculate the distance between two homologous nodes. + """ d = 0 d += jnp.abs(n1[1] - n2[1]) # bias d += jnp.abs(n1[2] - n2[2]) # response d += n1[3] != n2[3] # activation - d += n1[4] != n2[4] + d += n1[4] != n2[4] # aggregation return d @jit -def homologous_connection_distance(c1, c2): +def homologous_connection_distance(c1: Array, c2: Array): + """ + Calculate the distance between two homologous connections. + """ d = 0 d += jnp.abs(c1[2] - c2[2]) # weight d += c1[3] != c2[3] # enable diff --git a/neat/genome/distance_.py b/neat/genome/distance_.py deleted file mode 100644 index 69e421e..0000000 --- a/neat/genome/distance_.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Calculate the distance between two genomes. -The calculation method is the same as the distance calculation in NEAT-python. -See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py -""" -from typing import Dict - -from jax import jit, vmap, Array -from jax import numpy as jnp - -from .utils import EMPTY_NODE, EMPTY_CON - - -@jit -def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array: - """ - Calculate the distance between two genomes. - args: - nodes1: Array(N, 5) - cons1: Array(C, 4) - nodes2: Array(N, 5) - cons2: Array(C, 4) - returns: - distance: Array(, ) - """ - nd = node_distance(nodes1, nodes2, jit_config) # node distance - cd = connection_distance(cons1, cons2, jit_config) # connection distance - return nd + cd - - -@jit -def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict): - """ - Calculate the distance between nodes of two genomes. - """ - # statistics nodes count of two genomes - node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) - node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) - max_cnt = jnp.maximum(node_cnt1, node_cnt2) - - # align homologous nodes - # this process is similar to np.intersect1d. - nodes = jnp.concatenate((nodes1, nodes2), axis=0) - keys = nodes[:, 0] - sorted_indices = jnp.argsort(keys, axis=0) - nodes = nodes[sorted_indices] - nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end - fr, sr = nodes[:-1], nodes[1:] # first row, second row - - # flag location of homologous nodes - intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) - - # calculate the count of non_homologous of two genomes - non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) - - # calculate the distance of homologous nodes - hnd = vmap(homologous_node_distance)(fr, sr) - hnd = jnp.where(jnp.isnan(hnd), 0, hnd) - homologous_distance = jnp.sum(hnd * intersect_mask) - - val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ - 'compatibility_weight'] - - return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division - - -@jit -def connection_distance(cons1: Array, cons2: Array, jit_config: Dict): - """ - Calculate the distance between connections of two genomes. - Similar process as node_distance. - """ - con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) - con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) - max_cnt = jnp.maximum(con_cnt1, con_cnt2) - - cons = jnp.concatenate((cons1, cons2), axis=0) - keys = cons[:, :2] - sorted_indices = jnp.lexsort(keys.T[::-1]) - cons = cons[sorted_indices] - cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end - fr, sr = cons[:-1], cons[1:] # first row, second row - - # both genome has such connection - intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) - - non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - hcd = vmap(homologous_connection_distance)(fr, sr) - hcd = jnp.where(jnp.isnan(hcd), 0, hcd) - homologous_distance = jnp.sum(hcd * intersect_mask) - - val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ - 'compatibility_weight'] - - return jnp.where(max_cnt == 0, 0, val / max_cnt) - - -@jit -def homologous_node_distance(n1: Array, n2: Array): - """ - Calculate the distance between two homologous nodes. - """ - d = 0 - d += jnp.abs(n1[1] - n2[1]) # bias - d += jnp.abs(n1[2] - n2[2]) # response - d += n1[3] != n2[3] # activation - d += n1[4] != n2[4] # aggregation - return d - - -@jit -def homologous_connection_distance(c1: Array, c2: Array): - """ - Calculate the distance between two homologous connections. - """ - d = 0 - d += jnp.abs(c1[2] - c2[2]) # weight - d += c1[3] != c2[3] # enable - return d diff --git a/neat/genome/forward.py b/neat/genome/forward.py index da150ae..9eeb7e5 100644 --- a/neat/genome/forward.py +++ b/neat/genome/forward.py @@ -2,47 +2,82 @@ import jax from jax import Array, numpy as jnp from jax import jit, vmap -from .aggregations import agg -from .activations import act from .utils import I_INT # TODO: enabled information doesn't influence forward. That is wrong! -@jit -def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array, - input_idx: Array, output_idx: Array) -> Array: - """ - jax forward for single input shaped (input_num, ) - nodes, connections are single genome +def create_forward(config): + def act(idx, z): + """ + calculate activation function for each node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + # change idx from float to int + res = jax.lax.switch(idx, config['activation_funcs'], z) + return res - :argument inputs: (input_num, ) - :argument input_idx: (input_num, ) - :argument output_idx: (output_num, ) - :argument cal_seqs: (N, ) - :argument nodes: (N, 5) - :argument connections: (2, N, N) + def agg(idx, z): + """ + calculate activation function for inputs of node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) - :return (output_num, ) - """ - N = nodes.shape[0] - ini_vals = jnp.full((N,), jnp.nan) - ini_vals = ini_vals.at[input_idx].set(inputs) + def all_nan(): + return 0. - def scan_body(carry, i): - def hit(): - ins = carry * connections[0, :, i] - z = agg(nodes[i, 4], ins) - z = z * nodes[i, 2] + nodes[i, 1] - z = act(nodes[i, 3], z) + def not_all_nan(): + return jax.lax.switch(idx, config['aggregation_funcs'], z) - new_vals = carry.at[i].set(z) - return new_vals + return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) - def miss(): - return carry + def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array: + """ + jax forward for single input shaped (input_num, ) + nodes, connections are a single genome - return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None + :argument inputs: (input_num, ) + :argument cal_seqs: (N, ) + :argument nodes: (N, 5) + :argument connections: (2, N, N) - vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs) + :return (output_num, ) + """ - return vals[output_idx] + input_idx = config['input_idx'] + output_idx = config['output_idx'] + + N = nodes.shape[0] + ini_vals = jnp.full((N,), jnp.nan) + ini_vals = ini_vals.at[input_idx].set(inputs) + + weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :]) # enabled + + def cond_fun(carry): + values, idx = carry + return (idx < N) & (cal_seqs[idx] != I_INT) + + def body_func(carry): + values, idx = carry + i = cal_seqs[idx] + + def hit(): + ins = values * weights[:, i] + z = agg(nodes[i, 4], ins) # z = agg(ins) + z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias + z = act(nodes[i, 3], z) # z = act(z) + + new_values = values.at[i].set(z) + return new_values + + def miss(): + return values + + # the val of input nodes is obtained by the task, not by calculation + values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) + return values, idx + 1 + + vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) + + return vals[output_idx] + + return forward diff --git a/neat/genome/genome_.py b/neat/genome/genome.py similarity index 86% rename from neat/genome/genome_.py rename to neat/genome/genome.py index 832de39..4f2d32b 100644 --- a/neat/genome/genome_.py +++ b/neat/genome/genome.py @@ -44,10 +44,13 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]: pop_nodes[:, input_idx, 0] = input_idx pop_nodes[:, output_idx, 0] = output_idx - pop_nodes[:, output_idx, 1] = config['bias_init_mean'] - pop_nodes[:, output_idx, 2] = config['response_init_mean'] - pop_nodes[:, output_idx, 3] = config['activation_default'] - pop_nodes[:, output_idx, 4] = config['aggregation_default'] + # pop_nodes[:, output_idx, 1] = config['bias_init_mean'] + pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'], + size=(config['pop_size'], 1)) + pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'], + size=(config['pop_size'], 1)) + pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'], size=(config['pop_size'], 1)) + pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'], size=(config['pop_size'], 1)) grid_a, grid_b = np.meshgrid(input_idx, output_idx) grid_a, grid_b = grid_a.flatten(), grid_b.flatten() @@ -55,7 +58,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]: p = config['num_inputs'] * config['num_outputs'] pop_cons[:, :p, 0] = grid_a pop_cons[:, :p, 1] = grid_b - pop_cons[:, :p, 2] = config['weight_init_mean'] + pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'], + size=(config['pop_size'], p)) pop_cons[:, :p, 3] = 1 return pop_nodes, pop_cons diff --git a/neat/genome/graph.py b/neat/genome/graph.py index 15592f2..6741cba 100644 --- a/neat/genome/graph.py +++ b/neat/genome/graph.py @@ -8,8 +8,7 @@ from jax import jit, vmap, Array from jax import numpy as jnp # from .configs import fetch_first, I_INT -from neat.genome.utils import fetch_first, I_INT -from .utils import unflatten_connections +from neat.genome.utils import fetch_first, I_INT, unflatten_connections @jit @@ -44,49 +43,32 @@ def topological_sort(nodes: Array, connections: Array) -> Array: topological_sort(nodes, connections) -> [0, 1, 2, 3] """ - connections_enable = connections[1, :, :] == 1 + connections_enable = connections[1, :, :] == 1 # forward function. thus use enable in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0)) res = jnp.full(in_degree.shape, I_INT) - idx = 0 - def scan_body(carry, _): + def cond_fun(carry): + res_, idx_, in_degree_ = carry + i = fetch_first(in_degree_ == 0.) + return i != I_INT + + def body_func(carry): res_, idx_, in_degree_ = carry i = fetch_first(in_degree_ == 0.) - def hit(): - # add to res and flag it is already in it - new_res = res_.at[idx_].set(i) - new_idx = idx_ + 1 - new_in_degree = in_degree_.at[i].set(-1) + # add to res and flag it is already in it + res_ = res_.at[idx_].set(i) + in_degree_ = in_degree_.at[i].set(-1) - # decrease in_degree of all its children - children = connections_enable[i, :] - new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree) - return new_res, new_idx, new_in_degree - - def miss(): - return res_, idx_, in_degree_ - - return jax.lax.cond(i == I_INT, miss, hit), None - - scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0]) - res, _, _ = scan_res + # decrease in_degree of all its children + children = connections_enable[i, :] + in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_) + return res_, idx_ + 1, in_degree_ + res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree)) return res -@jit -@vmap -def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array: - """ - batch version of topological_sort - :param pop_nodes: - :param pop_connections: - :return: - """ - return topological_sort(pop_nodes, pop_connections) - - @jit def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array: """ @@ -131,22 +113,26 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra check_cycles(nodes, connections, 1, 0) -> False """ - connections = unflatten_connections(nodes, connections) - connections_enable = ~jnp.isnan(connections[0, :, :]) - connections_enable = connections_enable.at[from_idx, to_idx].set(True) - nodes_visited = jnp.full(nodes.shape[0], False) - nodes_visited = nodes_visited.at[to_idx].set(True) - def scan_body(visited, _): - new_visited = jnp.dot(visited, connections_enable) - new_visited = jnp.logical_or(visited, new_visited) - return new_visited, None + visited = jnp.full(nodes.shape[0], False) + new_visited = visited.at[to_idx].set(True) - nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0]) + def cond_func(carry): + visited_, new_visited_ = carry + end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited + end_cond2 = new_visited_[from_idx] # the starting node has been visited + return jnp.logical_not(end_cond1 | end_cond2) - return nodes_visited[from_idx] + def body_func(carry): + _, visited_ = carry + new_visited_ = jnp.dot(visited_, connections_enable) + new_visited_ = jnp.logical_or(visited_, new_visited_) + return visited_, new_visited_ + + _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) + return visited[from_idx] if __name__ == '__main__': diff --git a/neat/genome/mutate.py b/neat/genome/mutate.py index 0638377..93c6c15 100644 --- a/neat/genome/mutate.py +++ b/neat/genome/mutate.py @@ -1,155 +1,64 @@ -from typing import Tuple +""" +Mutate a genome. +The calculation method is the same as the mutation operation in NEAT-python. +See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate +""" +from typing import Tuple, Dict from functools import partial import jax -import numpy as np from jax import numpy as jnp -from jax import jit, vmap, Array +from jax import jit, Array -from .utils import fetch_random, fetch_first, I_INT, unflatten_connections +from .utils import fetch_random, fetch_first, I_INT from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection from .graph import check_cycles -# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible. @jit -def mutate(rand_key: Array, - nodes: Array, - connections: Array, - new_node_key: int, - input_idx: Array, - output_idx: Array, - bias_mean: float = 0, - bias_std: float = 1, - bias_mutate_strength: float = 0.5, - bias_mutate_rate: float = 0.7, - bias_replace_rate: float = 0.1, - response_mean: float = 1., - response_std: float = 0., - response_mutate_strength: float = 0., - response_mutate_rate: float = 0., - response_replace_rate: float = 0., - weight_mean: float = 0., - weight_std: float = 1., - weight_mutate_strength: float = 0.5, - weight_mutate_rate: float = 0.7, - weight_replace_rate: float = 0.1, - act_default: int = 0, - act_list: Array = None, - act_replace_rate: float = 0.1, - agg_default: int = 0, - agg_list: Array = None, - agg_replace_rate: float = 0.1, - enabled_reverse_rate: float = 0.1, - add_node_rate: float = 0.2, - delete_node_rate: float = 0.2, - add_connection_rate: float = 0.4, - delete_connection_rate: float = 0.4, - ): +def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict): """ - :param output_idx: - :param input_idx: - :param agg_default: - :param act_default: :param rand_key: :param nodes: (N, 5) :param connections: (2, N, N) :param new_node_key: - :param bias_mean: - :param bias_std: - :param bias_mutate_strength: - :param bias_mutate_rate: - :param bias_replace_rate: - :param response_mean: - :param response_std: - :param response_mutate_strength: - :param response_mutate_rate: - :param response_replace_rate: - :param weight_mean: - :param weight_std: - :param weight_mutate_strength: - :param weight_mutate_rate: - :param weight_replace_rate: - :param act_list: - :param act_replace_rate: - :param agg_list: - :param agg_replace_rate: - :param enabled_reverse_rate: - :param add_node_rate: - :param delete_node_rate: - :param add_connection_rate: - :param delete_connection_rate: + :param jit_config: :return: """ - - def m_add_node(rk, n, c): - return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default) - - def m_add_connection(rk, n, c): - return mutate_add_connection(rk, n, c, input_idx, output_idx) - - def m_delete_node(rk, n, c): - return mutate_delete_node(rk, n, c, input_idx, output_idx) - - def m_delete_connection(rk, n, c): - return mutate_delete_connection(rk, n, c) - r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) + # structural mutations # mutate add node - aux_nodes, aux_connections = m_add_node(r1, nodes, connections) - nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes) - connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections) + r = rand(r1) + aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config) + nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections) # mutate add connection - aux_nodes, aux_connections = m_add_connection(r3, nodes, connections) - nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes) - connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections) + r = rand(r2) + aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config) + nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections) # mutate delete node - aux_nodes, aux_connections = m_delete_node(r2, nodes, connections) - nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes) - connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections) + r = rand(r3) + aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config) + nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections) # mutate delete connection - aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections) - nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes) - connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections) + r = rand(r4) + aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections) + nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections) - nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength, - bias_mutate_rate, bias_replace_rate, response_mean, response_std, - response_mutate_strength, response_mutate_rate, response_replace_rate, - weight_mean, weight_std, weight_mutate_strength, - weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list, - agg_replace_rate, enabled_reverse_rate) + # value mutations + nodes, connections = mutate_values(rand_key, nodes, connections, jit_config) return nodes, connections -@jit -def mutate_values(rand_key: Array, - nodes: Array, - cons: Array, - bias_mean: float = 0, - bias_std: float = 1, - bias_mutate_strength: float = 0.5, - bias_mutate_rate: float = 0.7, - bias_replace_rate: float = 0.1, - response_mean: float = 1., - response_std: float = 0., - response_mutate_strength: float = 0., - response_mutate_rate: float = 0., - response_replace_rate: float = 0., - weight_mean: float = 0., - weight_std: float = 1., - weight_mutate_strength: float = 0.5, - weight_mutate_rate: float = 0.7, - weight_replace_rate: float = 0.1, - act_list: Array = None, - act_replace_rate: float = 0.1, - agg_list: Array = None, - agg_replace_rate: float = 0.1, - enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]: +def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: """ Mutate values of nodes and connections. @@ -157,56 +66,48 @@ def mutate_values(rand_key: Array, rand_key: A random key for generating random values. nodes: A 2D array representing nodes. cons: A 3D array representing connections. - bias_mean: Mean of the bias values. - bias_std: Standard deviation of the bias values. - bias_mutate_strength: Strength of the bias mutation. - bias_mutate_rate: Rate of the bias mutation. - bias_replace_rate: Rate of the bias replacement. - response_mean: Mean of the response values. - response_std: Standard deviation of the response values. - response_mutate_strength: Strength of the response mutation. - response_mutate_rate: Rate of the response mutation. - response_replace_rate: Rate of the response replacement. - weight_mean: Mean of the weight values. - weight_std: Standard deviation of the weight values. - weight_mutate_strength: Strength of the weight mutation. - weight_mutate_rate: Rate of the weight mutation. - weight_replace_rate: Rate of the weight replacement. - act_list: List of the activation function values. - act_replace_rate: Rate of the activation function replacement. - agg_list: List of the aggregation function values. - agg_replace_rate: Rate of the aggregation function replacement. - enabled_reverse_rate: Rate of reversing enabled state of connections. + jit_config: A dict containing configuration for jit-able functions. Returns: A tuple containing mutated nodes and connections. """ k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6) - bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std, - bias_mutate_strength, bias_mutate_rate, bias_replace_rate) - response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std, - response_mutate_strength, response_mutate_rate, response_replace_rate) - weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std, - weight_mutate_strength, weight_mutate_rate, weight_replace_rate) - act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate) - agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate) - # mutate enabled + # bias + bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'], + jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'], + jit_config['bias_replace_rate']) + + # response + response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'], + jit_config['response_init_std'], jit_config['response_mutate_power'], + jit_config['response_mutate_rate'], jit_config['response_replace_rate']) + + # weight + weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'], + jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'], + jit_config['weight_replace_rate']) + + # activation + act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'], + jit_config['activation_replace_rate']) + + # aggregation + agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'], + jit_config['aggregation_replace_rate']) + + # enabled r = jax.random.uniform(rand_key, cons[:, 3].shape) - enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3]) - enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan) + enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3]) + + # merge + nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new]) + cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new]) - nodes = nodes.at[:, 1].set(bias_new) - nodes = nodes.at[:, 2].set(response_new) - nodes = nodes.at[:, 3].set(act_new) - nodes = nodes.at[:, 4].set(agg_new) - cons = cons.at[:, 2].set(weight_new) - cons = cons.at[:, 3].set(enabled_new) return nodes, cons -@jit def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: """ @@ -227,19 +128,26 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa k1, k2, k3, rand_key = jax.random.split(rand_key, num=4) noise = jax.random.normal(k1, old_vals.shape) * mutate_strength replace = jax.random.normal(k2, old_vals.shape) * std + mean + r = jax.random.uniform(k3, old_vals.shape) + + # default new_vals = old_vals + + # r in [0, mutate_rate), mutate new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals) + + # r in [mutate_rate, mutate_rate + replace_rate), replace new_vals = jnp.where( - jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate), - replace, + (mutate_rate < r) & (r < mutate_rate + replace_rate), + replace + new_vals * 0.0, # in case of nan replace to values new_vals ) + new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) return new_vals -@jit def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: """ Mutate integer values (act, agg) of a given array. @@ -256,26 +164,20 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace k1, k2, rand_key = jax.random.split(rand_key, num=3) replace_val = jax.random.choice(k1, val_list, old_vals.shape) r = jax.random.uniform(k2, old_vals.shape) - new_vals = old_vals - new_vals = jnp.where(r < replace_rate, replace_val, new_vals) - new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) + new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values + return new_vals -@jit def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int, - default_bias: float = 0, default_response: float = 1, - default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: + jit_config: Dict) -> Tuple[Array, Array]: """ Randomly add a new node from splitting a connection. :param rand_key: :param new_node_key: :param nodes: :param cons: - :param default_bias: - :param default_response: - :param default_act: - :param default_agg: + :param jit_config: :return: """ # randomly choose a connection @@ -287,12 +189,13 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in def successful_add_node(): # disable the connection new_nodes, new_cons = nodes, cons + + # set enable to false new_cons = new_cons.at[idx, 3].set(False) # add a new node - new_nodes, new_cons = \ - add_node(new_nodes, new_cons, new_node_key, - bias=default_bias, response=default_response, act=default_act, agg=default_agg) + new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1, + act=jit_config['activation_default'], agg=jit_config['aggregation_default']) # add two new connections w = new_cons[idx, 2] @@ -306,59 +209,53 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in return nodes, cons -# TODO: Need we really need to delete a node? -@jit -def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, - input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: +# TODO: Do we really need to delete a node? +def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: """ Randomly delete a node. Input and output nodes are not allowed to be deleted. :param rand_key: :param nodes: :param cons: - :param input_keys: - :param output_keys: + :param jit_config: :return: """ # randomly choose a node - node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys, - allow_input_keys=False, allow_output_keys=False) + key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'], + allow_input_keys=False, allow_output_keys=False) def nothing(): return nodes, cons def successful_delete_node(): # delete the node - aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx) + aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx) # delete all connections - aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis], + aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None], jnp.nan, aux_cons) return aux_nodes, aux_cons - nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node) + nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node) return nodes, cons -@jit -def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, - input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: +def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: """ Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, cycles are not allowed. :param rand_key: :param nodes: :param cons: - :param input_keys: - :param output_keys: + :param jit_config: :return: """ # randomly choose two nodes k1, k2 = jax.random.split(rand_key, num=2) - i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys, + i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'], allow_input_keys=True, allow_output_keys=True) - o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys, + o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'], allow_input_keys=False, allow_output_keys=True) con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) @@ -375,15 +272,14 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, return nodes, cons is_already_exist = con_idx != I_INT - unflattened = unflatten_connections(nodes, cons) - is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx) + + is_cycle = check_cycles(nodes, cons, from_idx, to_idx) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful]) return nodes, cons -@jit def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array): """ Randomly delete a connection. @@ -406,7 +302,6 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array): return nodes, cons -@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) def choice_node_key(rand_key: Array, nodes: Array, input_keys: Array, output_keys: Array, allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: @@ -435,7 +330,6 @@ def choice_node_key(rand_key: Array, nodes: Array, return key, idx -@jit def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]: """ Randomly choose a connection key from the given connections. @@ -452,6 +346,5 @@ def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[A return i_key, o_key, idx -@jit def rand(rand_key): return jax.random.uniform(rand_key, ()) diff --git a/neat/genome/mutate_.py b/neat/genome/mutate_.py deleted file mode 100644 index 693fa31..0000000 --- a/neat/genome/mutate_.py +++ /dev/null @@ -1,355 +0,0 @@ -""" -Mutate a genome. -The calculation method is the same as the mutation operation in NEAT-python. -See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate -""" -from typing import Tuple, Dict -from functools import partial - -import jax -from jax import numpy as jnp -from jax import jit, Array - -from .utils import fetch_random, fetch_first, I_INT -from .genome_ import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection -from .graph import check_cycles - - -@jit -def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict): - """ - :param rand_key: - :param nodes: (N, 5) - :param connections: (2, N, N) - :param new_node_key: - :param jit_config: - :return: - """ - r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) - - # structural mutations - # mutate add node - r = rand(r1) - aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config) - nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections) - - # mutate add connection - r = rand(r2) - aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config) - nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections) - - # mutate delete node - r = rand(r3) - aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config) - nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections) - - # mutate delete connection - r = rand(r4) - aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections) - nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections) - - # value mutations - nodes, connections = mutate_values(rand_key, nodes, connections, jit_config) - - return nodes, connections - - -def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: - """ - Mutate values of nodes and connections. - - Args: - rand_key: A random key for generating random values. - nodes: A 2D array representing nodes. - cons: A 3D array representing connections. - jit_config: A dict containing configuration for jit-able functions. - - Returns: - A tuple containing mutated nodes and connections. - """ - - k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6) - - # bias - bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'], - jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'], - jit_config['bias_replace_rate']) - - # response - response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'], - jit_config['response_init_std'], jit_config['response_mutate_power'], - jit_config['response_mutate_rate'], jit_config['response_replace_rate']) - - # weight - weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'], - jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'], - jit_config['weight_replace_rate']) - - # activation - act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'], - jit_config['activation_replace_rate']) - - # aggregation - agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'], - jit_config['aggregation_replace_rate']) - - # enabled - r = jax.random.uniform(rand_key, cons[:, 3].shape) - enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3]) - - # merge - nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new]) - cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new]) - - return nodes, cons - - -def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, - mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: - """ - Mutate float values of a given array. - - Args: - rand_key: A random key for generating random values. - old_vals: A 1D array of float values to be mutated. - mean: Mean of the values. - std: Standard deviation of the values. - mutate_strength: Strength of the mutation. - mutate_rate: Rate of the mutation. - replace_rate: Rate of the replacement. - - Returns: - A mutated 1D array of float values. - """ - k1, k2, k3, rand_key = jax.random.split(rand_key, num=4) - noise = jax.random.normal(k1, old_vals.shape) * mutate_strength - replace = jax.random.normal(k2, old_vals.shape) * std + mean - - r = jax.random.uniform(k3, old_vals.shape) - - # default - new_vals = old_vals - - # r in [0, mutate_rate), mutate - new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals) - - # r in [mutate_rate, mutate_rate + replace_rate), replace - new_vals = jnp.where( - (mutate_rate < r) & (r < mutate_rate + replace_rate), - replace + new_vals * 0.0, # in case of nan replace to values - new_vals - ) - - new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) - return new_vals - - -def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: - """ - Mutate integer values (act, agg) of a given array. - - Args: - rand_key: A random key for generating random values. - old_vals: A 1D array of integer values to be mutated. - val_list: List of the integer values. - replace_rate: Rate of the replacement. - - Returns: - A mutated 1D array of integer values. - """ - k1, k2, rand_key = jax.random.split(rand_key, num=3) - replace_val = jax.random.choice(k1, val_list, old_vals.shape) - r = jax.random.uniform(k2, old_vals.shape) - new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values - - return new_vals - - -def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int, - jit_config: Dict) -> Tuple[Array, Array]: - """ - Randomly add a new node from splitting a connection. - :param rand_key: - :param new_node_key: - :param nodes: - :param cons: - :param jit_config: - :return: - """ - # randomly choose a connection - i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) - - def nothing(): # there is no connection to split - return nodes, cons - - def successful_add_node(): - # disable the connection - new_nodes, new_cons = nodes, cons - - # set enable to false - new_cons = new_cons.at[idx, 3].set(False) - - # add a new node - new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1, - act=jit_config['activation_default'], agg=jit_config['aggregation_default']) - - # add two new connections - w = new_cons[idx, 2] - new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True) - new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True) - return new_nodes, new_cons - - # if from_idx == I_INT, that means no connection exist, do nothing - nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node) - - return nodes, cons - - -# TODO: Do we really need to delete a node? -@jit -def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: - """ - Randomly delete a node. Input and output nodes are not allowed to be deleted. - :param rand_key: - :param nodes: - :param cons: - :param jit_config: - :return: - """ - # randomly choose a node - key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'], - allow_input_keys=False, allow_output_keys=False) - - def nothing(): - return nodes, cons - - def successful_delete_node(): - # delete the node - aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx) - - # delete all connections - aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None], - jnp.nan, aux_cons) - - return aux_nodes, aux_cons - - nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node) - - return nodes, cons - - -@jit -def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: - """ - Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, - cycles are not allowed. - :param rand_key: - :param nodes: - :param cons: - :param jit_config: - :return: - """ - # randomly choose two nodes - k1, k2 = jax.random.split(rand_key, num=2) - i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'], - allow_input_keys=True, allow_output_keys=True) - o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'], - allow_input_keys=False, allow_output_keys=True) - - con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) - - def successful(): - new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True) - return new_nodes, new_cons - - def already_exist(): - new_cons = cons.at[con_idx, 3].set(True) - return nodes, new_cons - - def cycle(): - return nodes, cons - - is_already_exist = con_idx != I_INT - - is_cycle = check_cycles(nodes, cons, from_idx, to_idx) - - choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) - nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful]) - return nodes, cons - - -@jit -def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array): - """ - Randomly delete a connection. - :param rand_key: - :param nodes: - :param cons: - :return: - """ - # randomly choose a connection - i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) - - def nothing(): - return nodes, cons - - def successfully_delete_connection(): - return delete_connection_by_idx(nodes, cons, idx) - - nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) - - return nodes, cons - - -def choice_node_key(rand_key: Array, nodes: Array, - input_keys: Array, output_keys: Array, - allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: - """ - Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. - :param rand_key: - :param nodes: - :param input_keys: - :param output_keys: - :param allow_input_keys: - :param allow_output_keys: - :return: return its key and position(idx) - """ - - node_keys = nodes[:, 0] - mask = ~jnp.isnan(node_keys) - - if not allow_input_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) - - if not allow_output_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) - - idx = fetch_random(rand_key, mask) - key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) - return key, idx - - -@jit -def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]: - """ - Randomly choose a connection key from the given connections. - :param rand_key: - :param nodes: - :param cons: - :return: i_key, o_key, idx - """ - - idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0])) - i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan) - o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan) - - return i_key, o_key, idx - - -@jit -def rand(rand_key): - return jax.random.uniform(rand_key, ()) diff --git a/neat/genome/utils.py b/neat/genome/utils.py index 826cfae..9e1ef2f 100644 --- a/neat/genome/utils.py +++ b/neat/genome/utils.py @@ -1,5 +1,4 @@ from functools import partial -from typing import Tuple import jax from jax import numpy as jnp, Array @@ -11,20 +10,18 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan) @jit -def unflatten_connections(nodes, cons): +def unflatten_connections(nodes: Array, cons: Array): """ transform the (C, 4) connections to (2, N, N) - this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled - connections to nan, that means we dont consider such connection when forward; - :param cons: - :param nodes: + :param nodes: (N, 5) + :param cons: (C, 4) :return: """ N = nodes.shape[0] node_keys = nodes[:, 0] i_keys, o_keys = cons[:, 0], cons[:, 1] - i_idxs = key_to_indices(i_keys, node_keys) - o_idxs = key_to_indices(o_keys, node_keys) + i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) + o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) res = jnp.full((2, N, N), jnp.nan) # Is interesting that jax use clip when attach data in array @@ -34,8 +31,6 @@ def unflatten_connections(nodes, cons): return res - -@partial(vmap, in_axes=(0, None)) def key_to_indices(key, keys): return fetch_first(key == keys) @@ -46,27 +41,12 @@ def fetch_first(mask, default=I_INT) -> Array: fetch the first True index :param mask: array of bool :param default: the default value if no element satisfying the condition - :return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT - example: - >>> a = jnp.array([1, 2, 3, 4, 5]) - >>> fetch_first(a > 3) - 3 - >>> fetch_first(a > 30) - I_INT + :return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value """ idx = jnp.argmax(mask) return jnp.where(mask[idx], idx, default) -@jit -def fetch_last(mask, default=I_INT) -> Array: - """ - similar to fetch_first, but fetch the last True index - """ - reversed_idx = fetch_first(mask[::-1], default) - return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1) - - @jit def fetch_random(rand_key, mask, default=I_INT) -> Array: """ @@ -78,27 +58,8 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array: mask = jnp.where(true_cnt == 0, False, cumsum >= target) return fetch_first(mask, default) - @jit def argmin_with_mask(arr: Array, mask: Array) -> Array: masked_arr = jnp.where(mask, arr, jnp.inf) min_idx = jnp.argmin(masked_arr) - return min_idx - - -if __name__ == '__main__': - - a = jnp.array([1, 2, 3, 4, 5]) - print(fetch_first(a > 3)) - print(fetch_first(a > 30)) - - print(fetch_last(a > 3)) - print(fetch_last(a > 30)) - - rand_key = jax.random.PRNGKey(0) - - for t in [-1, 0, 1, 2, 3, 4, 5]: - for _ in range(10): - rand_key, _ = jax.random.split(rand_key) - print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2)) - print(t, fetch_random(rand_key, a > t)) + return min_idx \ No newline at end of file diff --git a/neat/genome/utils_.py b/neat/genome/utils_.py deleted file mode 100644 index 27ab91a..0000000 --- a/neat/genome/utils_.py +++ /dev/null @@ -1,102 +0,0 @@ -from functools import partial - -import jax -from jax import numpy as jnp, Array -from jax import jit, vmap - -I_INT = jnp.iinfo(jnp.int32).max # infinite int -EMPTY_NODE = jnp.full((1, 5), jnp.nan) -EMPTY_CON = jnp.full((1, 4), jnp.nan) - - -@jit -def unflatten_connections(nodes: Array, cons: Array): - """ - transform the (C, 4) connections to (2, N, N) - :param nodes: (N, 5) - :param cons: (C, 4) - :return: - """ - N = nodes.shape[0] - node_keys = nodes[:, 0] - i_keys, o_keys = cons[:, 0], cons[:, 1] - i_idxs = vmap(fetch_first, in_axes=(0, None))(i_keys, node_keys) - i_idxs = key_to_indices(i_keys, node_keys) - o_idxs = key_to_indices(o_keys, node_keys) - res = jnp.full((2, N, N), jnp.nan) - - # Is interesting that jax use clip when attach data in array - # however, it will do nothing set values in an array - res = res.at[0, i_idxs, o_idxs].set(cons[:, 2]) - res = res.at[1, i_idxs, o_idxs].set(cons[:, 3]) - - return res - - -@partial(vmap, in_axes=(0, None)) -def key_to_indices(key, keys): - return fetch_first(key == keys) - - -@jit -def fetch_first(mask, default=I_INT) -> Array: - """ - fetch the first True index - :param mask: array of bool - :param default: the default value if no element satisfying the condition - :return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT - example: - >>> a = jnp.array([1, 2, 3, 4, 5]) - >>> fetch_first(a > 3) - 3 - >>> fetch_first(a > 30) - I_INT - """ - idx = jnp.argmax(mask) - return jnp.where(mask[idx], idx, default) - - -@jit -def fetch_last(mask, default=I_INT) -> Array: - """ - similar to fetch_first, but fetch the last True index - """ - reversed_idx = fetch_first(mask[::-1], default) - return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1) - - -@jit -def fetch_random(rand_key, mask, default=I_INT) -> Array: - """ - similar to fetch_first, but fetch a random True index - """ - true_cnt = jnp.sum(mask) - cumsum = jnp.cumsum(mask) - target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) - mask = jnp.where(true_cnt == 0, False, cumsum >= target) - return fetch_first(mask, default) - - -@jit -def argmin_with_mask(arr: Array, mask: Array) -> Array: - masked_arr = jnp.where(mask, arr, jnp.inf) - min_idx = jnp.argmin(masked_arr) - return min_idx - - -if __name__ == '__main__': - - a = jnp.array([1, 2, 3, 4, 5]) - print(fetch_first(a > 3)) - print(fetch_first(a > 30)) - - print(fetch_last(a > 3)) - print(fetch_last(a > 30)) - - rand_key = jax.random.PRNGKey(0) - - for t in [-1, 0, 1, 2, 3, 4, 5]: - for _ in range(10): - rand_key, _ = jax.random.split(rand_key) - print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2)) - print(t, fetch_random(rand_key, a > t)) diff --git a/neat/pipeline.py b/neat/pipeline.py index c3326d3..a99fd5b 100644 --- a/neat/pipeline.py +++ b/neat/pipeline.py @@ -1,158 +1,78 @@ -from typing import List, Union, Tuple, Callable -import time +from functools import partial -import jax import numpy as np +import jax -from .species import SpeciesController -from .genome import expand, expand_single +from configs.configer import Configer +from .genome.genome import initialize_genomes from .function_factory import FunctionFactory -from .population import * - class Pipeline: """ Neat algorithm pipeline. """ - def __init__(self, config, function_factory, seed=42): - self.time_dict = {} - self.function_factory = function_factory - + def __init__(self, config, function_factory=None, seed=42): self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) - self.config = config - self.N = config.basic.init_maximum_nodes - self.C = config.basic.init_maximum_connections - self.S = config.basic.init_maximum_species - self.expand_coe = config.basic.expands_coe - self.pop_size = config.neat.population.pop_size + self.config = config # global config + self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions + self.function_factory = function_factory or FunctionFactory(self.config) - self.species_controller = SpeciesController(config) - self.initialize_func = self.function_factory.create_initialize(self.N, self.C) - self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func() - - self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S) + self.symbols = { + 'P': self.config['pop_size'], + 'N': self.config['init_maximum_nodes'], + 'C': self.config['init_maximum_connections'], + 'S': self.config['init_maximum_species'], + } self.generation = 0 - self.generation_time_list = [] - self.species_controller.init_speciate(self.pop_nodes, self.pop_cons) - - self.best_fitness = float('-inf') self.best_genome = None - self.generation_timestamp = time.time() - self.evaluate_time = 0 + self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config) + def ask(self): """ - Create a forward function for the population. - :return: - Algorithm gives the population a forward function, then environment gives back the fitnesses. + Creates a function that receives a genome and returns a forward function. + There are 3 types of config['forward_way']: {'single', 'pop', 'common'} + + single: + Create pop_size number of forward functions. + Each function receive (batch_size, input_size) and returns (batch_size, output_size) + e.g. RL task + + pop: + Create a single forward function, which use only once calculation for the population. + The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size) + + common: + Special case of pop. The population has the same inputs. + The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size) + e.g. numerical regression; Hyper-NEAT + """ - return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_cons) + u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons) + pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons) - def tell(self, fitnesses): + if self.config['forward_way'] == 'single': + forward_funcs = [] + for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons): + func = lambda x: self.get_func('forward')(x, seq, nodes, cons) + forward_funcs.append(func) + return forward_funcs - self.generation += 1 + elif self.config['forward_way'] == 'pop': + func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) + return func - winner_part, loser_part, elite_mask, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start = self.species_controller.ask( - fitnesses, - self.generation, - self.S, self.N, self.C) + elif self.config['forward_way'] == 'common': + func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) + return func - new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) - self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = self.create_and_speciate( - self.randkey, self.pop_nodes, self.pop_cons, winner_part, loser_part, elite_mask, - new_node_keys, - pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start) - - - self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \ - jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys]) - - self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation) - - self.expand() - - def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): - for _ in range(self.config.neat.population.generation_limit): - forward_func = self.ask() - - tic = time.time() - fitnesses = fitness_func(forward_func) - self.evaluate_time += time.time() - tic - - assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" - - if analysis is not None: - if analysis == "default": - self.default_analysis(fitnesses) - else: - assert callable(analysis), f"What the fuck you passed in? A {analysis}?" - analysis(fitnesses) - - if max(fitnesses) >= self.config.neat.population.fitness_threshold: - print("Fitness limit reached!") - return self.best_genome - - self.tell(fitnesses) - print("Generation limit reached!") - return self.best_genome - - def expand(self): - """ - Expand the population if needed. - :return: - when the maximum node number of the population >= N - the population will expand - """ - pop_node_keys = self.pop_nodes[:, :, 0] - pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1) - max_node_size = np.max(pop_node_sizes) - if max_node_size >= self.N: - self.N = int(self.N * self.expand_coe) - # self.C = int(self.C * self.expand_coe) - print(f"node expand to {self.N}!") - self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C) - - # don't forget to expand representation genome in species - for s in self.species_controller.species.values(): - s.representative = expand_single(*s.representative, self.N, self.C) - - - pop_con_keys = self.pop_cons[:, :, 0] - pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) - max_con_size = np.max(pop_node_sizes) - if max_con_size >= self.C: - # self.N = int(self.N * self.expand_coe) - self.C = int(self.C * self.expand_coe) - print(f"connections expand to {self.C}!") - self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C) - - # don't forget to expand representation genome in species - for s in self.species_controller.species.values(): - s.representative = expand_single(*s.representative, self.N, self.C) - - self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S) - - - - def default_analysis(self, fitnesses): - max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) - species_sizes = [len(s.members) for s in self.species_controller.species.values()] - - new_timestamp = time.time() - cost_time = new_timestamp - self.generation_timestamp - self.generation_time_list.append(cost_time) - self.generation_timestamp = new_timestamp - - max_idx = np.argmax(fitnesses) - if fitnesses[max_idx] > self.best_fitness: - self.best_fitness = fitnesses[max_idx] - self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) - - print(f"Generation: {self.generation}", - f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") + else: + raise NotImplementedError + def get_func(self, name): + return self.function_factory.get(name, self.symbols) diff --git a/neat/pipeline_.py b/neat/pipeline_.py deleted file mode 100644 index adafe00..0000000 --- a/neat/pipeline_.py +++ /dev/null @@ -1,27 +0,0 @@ -import jax - -from configs.configer import Configer -from .genome.genome_ import initialize_genomes - - -class Pipeline: - """ - Neat algorithm pipeline. - """ - - def __init__(self, config, seed=42): - self.randkey = jax.random.PRNGKey(seed) - - self.config = config # global config - self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions - self.N = self.config["init_maximum_nodes"] - self.C = self.config["init_maximum_connections"] - self.S = self.config["init_maximum_species"] - - self.generation = 0 - self.best_genome = None - - self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config) - - print(self.pop_nodes, self.pop_cons, sep='\n') - print(self.jit_config)