From 1f2327bbd672a12525f9331fac8b9a088237cad9 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 9 May 2023 01:58:00 +0800 Subject: [PATCH] clean imports and delete "create_XXX_functions" --- algorithms/neat/function_factory.py | 6 +- algorithms/neat/genome/__init__.py | 8 +- algorithms/neat/genome/crossover.py | 37 +------- algorithms/neat/genome/distance.py | 130 ---------------------------- algorithms/neat/genome/genome.py | 16 ---- algorithms/neat/genome/mutate.py | 105 +++------------------- algorithms/neat/pipeline.py | 4 +- 7 files changed, 20 insertions(+), 286 deletions(-) diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index 20bf671..d29d0c0 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -6,11 +6,7 @@ from functools import partial import numpy as np from jax import jit, vmap -from .genome import act_name2key, agg_name2key -from .genome.genome import initialize_genomes -from .genome.mutate import mutate -from .genome.distance import distance -from .genome.crossover import crossover +from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover class FunctionFactory: diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py index fdd306b..f276f13 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -1,7 +1,7 @@ -from .genome import create_initialize_function, expand, expand_single, pop_analysis -from .distance import create_distance_function -from .mutate import create_mutate_function +from .genome import expand, expand_single, pop_analysis, initialize_genomes from .forward import create_forward_function -from .crossover import create_crossover_function from .activations import act_name2key from .aggregations import agg_name2key +from .crossover import crossover +from .mutate import mutate +from .distance import distance diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index d5eaf2c..39bc0ce 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -8,38 +8,7 @@ from jax import numpy as jnp from .utils import flatten_connections, unflatten_connections -def create_crossover_function(N, config, batch: bool, debug: bool = False): - if batch: - pop_size = config.neat.population.pop_size - randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32) - nodes1_lower = jnp.zeros((pop_size, N, 5)) - connections1_lower = jnp.zeros((pop_size, 2, N, N)) - nodes2_lower = jnp.zeros((pop_size, N, 5)) - connections2_lower = jnp.zeros((pop_size, 2, N, N)) - - res_func = jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower, - nodes2_lower, connections2_lower).compile() - if debug: - return lambda *args: res_func(*args) - else: - return res_func - - else: - randkey_lower = jnp.zeros((2,), dtype=jnp.uint32) - nodes1_lower = jnp.zeros((N, 5)) - connections1_lower = jnp.zeros((2, N, N)) - nodes2_lower = jnp.zeros((N, 5)) - connections2_lower = jnp.zeros((2, N, N)) - - res_func = jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower, - nodes2_lower, connections2_lower).compile() - if debug: - return lambda *args: res_func(*args) - else: - return res_func - - -# @jit +@jit def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ -> Tuple[Array, Array]: """ @@ -70,7 +39,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, return new_nodes, new_cons -# @partial(jit, static_argnames=['gene_type']) +@partial(jit, static_argnames=['gene_type']) def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: """ make ar2 align with ar1. @@ -97,7 +66,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: return refactor_ar2 -# @jit +@jit def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: """ crossover two genes diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 4a78b3f..8a5db3e 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -1,139 +1,9 @@ from jax import jit, vmap, Array from jax import numpy as jnp -import numpy as np -from numpy.typing import NDArray from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON -def create_distance_function(N, config, type: str, debug: bool = False): - """ - :param N: - :param config: - :param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation - :param debug: - :return: - """ - disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient - compatibility_coe = config.neat.genome.compatibility_weight_coefficient - - def distance_with_args(nodes1, connections1, nodes2, connections2): - return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) - - if type == 'o2o': - nodes1_lower = jnp.zeros((N, 5)) - connections1_lower = jnp.zeros((2, N, N)) - nodes2_lower = jnp.zeros((N, 5)) - connections2_lower = jnp.zeros((2, N, N)) - - res_func = jit(distance_with_args).lower(nodes1_lower, connections1_lower, - nodes2_lower, connections2_lower).compile() - if debug: - return lambda *args: res_func(*args) # for debug - else: - return res_func - - elif type == 'o2m': - vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) - pop_size = config.neat.population.pop_size - nodes1_lower = jnp.zeros((N, 5)) - connections1_lower = jnp.zeros((2, N, N)) - nodes2_lower = jnp.zeros((pop_size, N, 5)) - connections2_lower = jnp.zeros((pop_size, 2, N, N)) - res_func = jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile() - if debug: - return lambda *args: res_func(*args) # for debug - else: - return res_func - - else: - raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]') - - -def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray, - connection2: NDArray, disjoint_coe: float = 1., compatibility_coe: float = 0.5): - """ - use in o2o distance. - o2o can't use vmap, numpy should be faster than jax function - :param nodes1: - :param connection1: - :param nodes2: - :param connection2: - :param disjoint_coe: - :param compatibility_coe: - :return: - """ - - def analysis(nodes, connections): - nodes_dict = {} - idx2key = {} - for i, node in enumerate(nodes): - if np.isnan(node[0]): - continue - key = int(node[0]) - nodes_dict[key] = (node[1], node[2], node[3], node[4]) - idx2key[i] = key - - connections_dict = {} - for i in range(connections.shape[1]): - for j in range(connections.shape[2]): - if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]): - continue - key = (idx2key[i], idx2key[j]) - - weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None - enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None - connections_dict[key] = (weight, enabled) - - return nodes_dict, connections_dict - - nodes1, connections1 = analysis(nodes1, connection1) - nodes2, connections2 = analysis(nodes2, connection2) - - nd = 0.0 - if nodes1 or nodes2: # otherwise, both are empty - disjoint_nodes = 0 - for k2 in nodes2: - if k2 not in nodes1: - disjoint_nodes += 1 - - for k1, n1 in nodes1.items(): - n2 = nodes2.get(k1) - if n2 is None: - disjoint_nodes += 1 - else: - if np.isnan(n1[0]): # n1[1] is nan means input nodes - continue - d = abs(n1[0] - n2[0]) + abs(n1[1] - n2[1]) - d += 1 if n1[2] != n2[2] else 0 - d += 1 if n1[3] != n2[3] else 0 - nd += d - - max_nodes = max(len(nodes1), len(nodes2)) - nd = (compatibility_coe * nd + disjoint_coe * disjoint_nodes) / max_nodes - - cd = 0.0 - if connections1 or connections2: - disjoint_connections = 0 - for k2 in connections2: - if k2 not in connections1: - disjoint_connections += 1 - - for k1, c1 in connections1.items(): - c2 = connections2.get(k1) - if c2 is None: - disjoint_connections += 1 - else: - # Homologous genes compute their own distance value. - d = abs(c1[0] - c2[0]) - d += 1 if c1[1] != c2[1] else 0 - cd += d - max_conn = max(len(connections1), len(connections2)) - cd = (compatibility_coe * cd + disjoint_coe * disjoint_connections) / max_conn - - return nd + cd - - @jit def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1., compatibility_coe: float = 0.5) -> Array: diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 591f022..535b352 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -22,27 +22,11 @@ from jax import numpy as jnp from jax import jit from jax import Array -from .activations import act_name2key -from .aggregations import agg_name2key from .utils import fetch_first EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) -def create_initialize_function(config): - pop_size = config.neat.population.pop_size - N = config.basic.init_maximum_nodes - num_inputs = config.basic.num_inputs - num_outputs = config.basic.num_outputs - default_bias = config.neat.gene.bias.init_mean - default_response = config.neat.gene.response.init_mean - default_act = act_name2key[config.neat.gene.activation.default] - default_agg = agg_name2key[config.neat.gene.aggregation.default] - default_weight = config.neat.gene.weight.init_mean - return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response, - default_act, default_agg, default_weight) - - def initialize_genomes(pop_size: int, N: int, num_inputs: int, diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index af08637..ae669d9 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -13,100 +13,7 @@ from .activations import act_name2key from .aggregations import agg_name2key -def create_mutate_function(N, config, batch: bool, debug: bool = False): - """ - create mutate function for different situations - :param N: - :param config: - :param batch: mutate for population or not - :param debug: - :return: - """ - num_inputs = config.basic.num_inputs - num_outputs = config.basic.num_outputs - input_idx = np.arange(num_inputs) - output_idx = np.arange(num_inputs, num_inputs + num_outputs) - - bias = config.neat.gene.bias - bias_default = bias.init_mean - bias_mean = bias.init_mean - bias_std = bias.init_stdev - bias_mutate_strength = bias.mutate_power - bias_mutate_rate = bias.mutate_rate - bias_replace_rate = bias.replace_rate - - response = config.neat.gene.response - response_default = response.init_mean - response_mean = response.init_mean - response_std = response.init_stdev - response_mutate_strength = response.mutate_power - response_mutate_rate = response.mutate_rate - response_replace_rate = response.replace_rate - - weight = config.neat.gene.weight - weight_mean = weight.init_mean - weight_std = weight.init_stdev - weight_mutate_strength = weight.mutate_power - weight_mutate_rate = weight.mutate_rate - weight_replace_rate = weight.replace_rate - - activation = config.neat.gene.activation - act_default = act_name2key[activation.default] - act_list = np.array([act_name2key[name] for name in activation.options]) - act_replace_rate = activation.mutate_rate - - aggregation = config.neat.gene.aggregation - agg_default = agg_name2key[aggregation.default] - agg_list = np.array([agg_name2key[name] for name in aggregation.options]) - agg_replace_rate = aggregation.mutate_rate - - enabled = config.neat.gene.enabled - enabled_reverse_rate = enabled.mutate_rate - - genome = config.neat.genome - add_node_rate = genome.node_add_prob - delete_node_rate = genome.node_delete_prob - add_connection_rate = genome.conn_add_prob - delete_connection_rate = genome.conn_delete_prob - single_structure_mutate = genome.single_structural_mutation - - def mutate_with_args(rand_key, nodes, connections, new_node_key): - return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx, - bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, - bias_replace_rate, response_default, response_mean, response_std, - response_mutate_strength, response_mutate_rate, response_replace_rate, - weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, - weight_replace_rate, act_default, act_list, act_replace_rate, - agg_default, agg_list, agg_replace_rate, enabled_reverse_rate, - add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate, - single_structure_mutate) - - if not batch: - rand_key_lower = jnp.zeros((2,), dtype=jnp.uint32) - nodes_lower = jnp.zeros((N, 5)) - connections_lower = jnp.zeros((2, N, N)) - new_node_key_lower = jnp.zeros((), dtype=jnp.int32) - res_func = jit(mutate_with_args).lower(rand_key_lower, nodes_lower, - connections_lower, new_node_key_lower).compile() - if debug: - return lambda *args: res_func(*args) - else: - return res_func - else: - pop_size = config.neat.population.pop_size - rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32) - nodes_lower = jnp.zeros((pop_size, N, 5)) - connections_lower = jnp.zeros((pop_size, 2, N, N)) - new_node_key_lower = jnp.zeros((pop_size,), dtype=jnp.int32) - - batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower, - connections_lower, new_node_key_lower).compile() - if debug: - return lambda *args: batched_mutate_func(*args) - else: - return batched_mutate_func - - +@partial(jit, static_argnames=('single_structure_mutate',)) def mutate(rand_key: Array, nodes: Array, connections: Array, @@ -243,6 +150,7 @@ def mutate(rand_key: Array, return nodes, connections +@jit def mutate_values(rand_key: Array, nodes: Array, connections: Array, @@ -323,6 +231,7 @@ def mutate_values(rand_key: Array, return nodes, connections +@jit def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: """ @@ -355,6 +264,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa return new_vals +@jit def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: """ Mutate integer values (act, agg) of a given array. @@ -377,6 +287,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace return new_vals +@jit def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, default_bias: float = 0, default_response: float = 1, default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: @@ -423,6 +334,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection return nodes, connections +@jit def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ @@ -456,6 +368,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, return nodes, connections +@jit def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ @@ -494,6 +407,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, return nodes, connections +@jit def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): """ Randomly delete a connection. @@ -516,6 +430,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): return nodes, connections +@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) def choice_node_key(rand_key: Array, nodes: Array, input_keys: Array, output_keys: Array, allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: @@ -544,6 +459,7 @@ def choice_node_key(rand_key: Array, nodes: Array, return key, idx +@jit def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: """ Randomly choose a connection key from the given connections. @@ -571,5 +487,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T return from_key, to_key, from_idx, to_idx +@jit def rand(rand_key): return jax.random.uniform(rand_key, ()) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 3f2c2bc..a1160c6 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -5,9 +5,7 @@ import jax import numpy as np from .species import SpeciesController -from .genome import expand, expand_single -from .genome import create_initialize_function, create_mutate_function, create_forward_function, \ - create_distance_function, create_crossover_function +from .genome import expand, expand_single, create_forward_function from .function_factory import FunctionFactory