diff --git a/algorithms/neat/genome/activations.py b/algorithms/neat/genome/activations.py index 16f0d52..48df72f 100644 --- a/algorithms/neat/genome/activations.py +++ b/algorithms/neat/genome/activations.py @@ -134,5 +134,3 @@ def act(idx, z): # change idx from float to int return jax.lax.switch(idx, ACT_TOTAL_LIST, z) - -vectorized_act = jax.vmap(act, in_axes=(0, 0)) diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 2cf8cdb..898d867 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -48,7 +48,7 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask) if gene_type == 'node': - node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) + node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) else: # connection node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes) @@ -64,7 +64,17 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene -@partial(vmap, in_axes=(0, 0)) +@vmap +def batch_homologous_node_distance(b_n1, b_n2): + return homologous_node_distance(b_n1, b_n2) + + +@vmap +def batch_homologous_connection_distance(b_c1, b_c2): + return homologous_connection_distance(b_c1, b_c2) + + +@jit def homologous_node_distance(n1, n2): d = 0 d += jnp.abs(n1[1] - n2[1]) # bias @@ -74,7 +84,7 @@ def homologous_node_distance(n1, n2): return d -@partial(vmap, in_axes=(0, 0)) +@jit def homologous_connection_distance(c1, c2): d = 0 d += jnp.abs(c1[2] - c2[2]) # weight diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py index 8f753dd..44752a3 100644 --- a/algorithms/neat/genome/graph.py +++ b/algorithms/neat/genome/graph.py @@ -95,11 +95,11 @@ def topological_sort_debug(nodes: Array, connections: Array) -> Array: @vmap -def batch_topological_sort(nodes: Array, connections: Array) -> Array: +def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array: """ batch version of topological_sort - :param nodes: - :param connections: + :param pop_nodes: + :param pop_connections: :return: """ return topological_sort(nodes, connections) @@ -175,17 +175,17 @@ if __name__ == '__main__': ]) connections = jnp.array([ [ - [0, 0, 1, 0, jnp.nan], - [0, 0, 1, 1, jnp.nan], - [0, 0, 0, 1, jnp.nan], - [0, 0, 0, 0, jnp.nan], + [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], + [jnp.nan, jnp.nan, 1, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] ], [ - [0, 0, 1, 0, jnp.nan], - [0, 0, 1, 1, jnp.nan], - [0, 0, 0, 1, jnp.nan], - [0, 0, 0, 0, jnp.nan], + [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], + [jnp.nan, jnp.nan, 1, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] ] ] diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index 3334a2d..12227c9 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -386,18 +386,30 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection # randomly choose a connection from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) - # disable the connection - connections = connections.at[1, from_idx, to_idx].set(False) + def nothing(): + return nodes, connections - # add a new node - nodes, connections = add_node(new_node_key, nodes, connections, - bias=default_bias, response=default_response, act=default_act, agg=default_agg) - new_idx = fetch_first(nodes[:, 0] == new_node_key) + def successful_add_node(): + # disable the connection + new_nodes, new_connections = nodes, connections + new_connections = new_connections.at[1, from_idx, to_idx].set(False) - # add two new connections - weight = connections[0, from_idx, to_idx] - nodes, connections = add_connection_by_idx(from_idx, new_idx, nodes, connections, weight=0, enabled=True) - nodes, connections = add_connection_by_idx(new_idx, to_idx, nodes, connections, weight=weight, enabled=True) + # add a new node + new_nodes, new_connections = \ + add_node(new_node_key, new_nodes, new_connections, + bias=default_bias, response=default_response, act=default_act, agg=default_agg) + new_idx = fetch_first(new_nodes[:, 0] == new_node_key) + + # add two new connections + weight = new_connections[0, from_idx, to_idx] + new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx, + new_nodes, new_connections, weight=0, enabled=True) + new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx, + new_nodes, new_connections, weight=weight, enabled=True) + return new_nodes, new_connections + + # if from_idx == I_INT, that means no connection exist, do nothing + nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node) return nodes, connections @@ -482,7 +494,15 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): """ # randomly choose a connection from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) - nodes, connections = delete_connection_by_idx(from_idx, to_idx, nodes, connections) + + def nothing(): + return nodes, connections + + def successfully_delete_connection(): + return delete_connection_by_idx(from_idx, to_idx, nodes, connections) + + nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection) + return nodes, connections @@ -530,6 +550,10 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T col = connection[0, from_idx, :] to_idx = fetch_random(k2, ~jnp.isnan(col)) from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0] + + from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan) + to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan) + return from_key, to_key, from_idx, to_idx diff --git a/algorithms/neat/genome/numpy/__init__.py b/algorithms/neat/genome/numpy/__init__.py new file mode 100644 index 0000000..bb905b5 --- /dev/null +++ b/algorithms/neat/genome/numpy/__init__.py @@ -0,0 +1,5 @@ +from .genome import create_initialize_function, expand, expand_single +from .distance import distance +from .mutate import create_mutate_function +from .forward import create_forward_function +from .crossover import batch_crossover diff --git a/algorithms/neat/genome/numpy/activations.py b/algorithms/neat/genome/numpy/activations.py new file mode 100644 index 0000000..308faa0 --- /dev/null +++ b/algorithms/neat/genome/numpy/activations.py @@ -0,0 +1,113 @@ +import numpy as np + + +def sigmoid_act(z): + z = np.clip(z * 5, -60, 60) + return 1 / (1 + np.exp(-z)) + + +def tanh_act(z): + z = np.clip(z * 2.5, -60, 60) + return np.tanh(z) + + +def sin_act(z): + z = np.clip(z * 5, -60, 60) + return np.sin(z) + + +def gauss_act(z): + z = np.clip(z, -3.4, 3.4) + return np.exp(-5 * z ** 2) + + +def relu_act(z): + return np.maximum(z, 0) + + +def elu_act(z): + return np.where(z > 0, z, np.exp(z) - 1) + + +def lelu_act(z): + leaky = 0.005 + return np.where(z > 0, z, leaky * z) + + +def selu_act(z): + lam = 1.0507009873554804934193349852946 + alpha = 1.6732632423543772848170429916717 + return np.where(z > 0, lam * z, lam * alpha * (np.exp(z) - 1)) + + +def softplus_act(z): + z = np.clip(z * 5, -60, 60) + return 0.2 * np.log(1 + np.exp(z)) + + +def identity_act(z): + return z + + +def clamped_act(z): + return np.clip(z, -1, 1) + + +def inv_act(z): + return 1 / z + + +def log_act(z): + z = np.maximum(z, 1e-7) + return np.log(z) + + +def exp_act(z): + z = np.clip(z, -60, 60) + return np.exp(z) + + +def abs_act(z): + return np.abs(z) + + +def hat_act(z): + return np.maximum(0, 1 - np.abs(z)) + + +def square_act(z): + return z ** 2 + + +def cube_act(z): + return z ** 3 + + +ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act, + identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act] + +act_name2key = { + 'sigmoid': 0, + 'tanh': 1, + 'sin': 2, + 'gauss': 3, + 'relu': 4, + 'elu': 5, + 'lelu': 6, + 'selu': 7, + 'softplus': 8, + 'identity': 9, + 'clamped': 10, + 'inv': 11, + 'log': 12, + 'exp': 13, + 'abs': 14, + 'hat': 15, + 'square': 16, + 'cube': 17, +} + + +def act(idx, z): + idx = np.asarray(idx, dtype=np.int32) + return ACT_TOTAL_LIST[idx](z) diff --git a/algorithms/neat/genome/numpy/aggregations.py b/algorithms/neat/genome/numpy/aggregations.py new file mode 100644 index 0000000..b30daa1 --- /dev/null +++ b/algorithms/neat/genome/numpy/aggregations.py @@ -0,0 +1,86 @@ +""" +aggregations, two special case need to consider: +1. extra 0s +2. full of 0s +""" +import numpy as np + + +def sum_agg(z): + z = np.where(np.isnan(z), 0, z) + return np.sum(z, axis=0) + + +def product_agg(z): + z = np.where(np.isnan(z), 1, z) + return np.prod(z, axis=0) + + +def max_agg(z): + z = np.where(np.isnan(z), -np.inf, z) + return np.max(z, axis=0) + + +def min_agg(z): + z = np.where(np.isnan(z), np.inf, z) + return np.min(z, axis=0) + + +def maxabs_agg(z): + z = np.where(np.isnan(z), 0, z) + abs_z = np.abs(z) + max_abs_index = np.argmax(abs_z) + return z[max_abs_index] + + +def median_agg(z): + non_zero_mask = ~np.isnan(z) + n = np.sum(non_zero_mask, axis=0) + + z = np.where(np.isnan(z), np.inf, z) + sorted_valid_values = np.sort(z) + + if n % 2 == 0: + return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2 + else: + return sorted_valid_values[n // 2] + + +def mean_agg(z): + non_zero_mask = ~np.isnan(z) + valid_values_sum = sum_agg(z) + valid_values_count = np.sum(non_zero_mask, axis=0) + mean_without_zeros = valid_values_sum / valid_values_count + return mean_without_zeros + + +AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg] + +agg_name2key = { + 'sum': 0, + 'product': 1, + 'max': 2, + 'min': 3, + 'maxabs': 4, + 'median': 5, + 'mean': 6, +} + + +def agg(idx, z): + idx = np.asarray(idx, dtype=np.int32) + + if np.all(z == 0.): + return 0 + else: + return AGG_TOTAL_LIST[idx](z) + + +if __name__ == '__main__': + array = np.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=np.float32) + for names in agg_name2key.keys(): + print(names, agg(agg_name2key[names], array)) + + array2 = np.asarray([0, 0, 0, 0], dtype=np.float32) + for names in agg_name2key.keys(): + print(names, agg(agg_name2key[names], array2)) diff --git a/algorithms/neat/genome/numpy/crossover.py b/algorithms/neat/genome/numpy/crossover.py new file mode 100644 index 0000000..ac5be02 --- /dev/null +++ b/algorithms/neat/genome/numpy/crossover.py @@ -0,0 +1,90 @@ +from typing import Tuple + +import numpy as np +from numpy.typing import NDArray + +from .utils import flatten_connections, unflatten_connections + + +def batch_crossover(batch_nodes1: NDArray, batch_connections1: NDArray, batch_nodes2: NDArray, + batch_connections2: NDArray) -> Tuple[NDArray, NDArray]: + """ + crossover a batch of genomes + :param batch_nodes1: + :param batch_connections1: + :param batch_nodes2: + :param batch_connections2: + :return: + """ + res_nodes, res_cons = [], [] + for (n1, c1, n2, c2) in zip(batch_nodes1, batch_connections1, batch_nodes2, batch_connections2): + new_nodes, new_cons = crossover(n1, c1, n2, c2) + res_nodes.append(new_nodes) + res_cons.append(new_cons) + return np.stack(res_nodes, axis=0), np.stack(res_cons, axis=0) + + +def crossover(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) \ + -> Tuple[NDArray, NDArray]: + """ + use genome1 and genome2 to generate a new genome + notice that genome1 should have higher fitness than genome2 (genome1 is winner!) + :param nodes1: + :param connections1: + :param nodes2: + :param connections2: + :return: + """ + + # crossover nodes + keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + nodes2 = align_array(keys1, keys2, nodes2, 'node') + new_nodes = np.where(np.isnan(nodes1) | np.isnan(nodes2), nodes1, crossover_gene(nodes1, nodes2)) + + # crossover connections + cons1 = flatten_connections(keys1, connections1) + cons2 = flatten_connections(keys2, connections2) + con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] + cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') + new_cons = np.where(np.isnan(cons1) | np.isnan(cons2), cons1, crossover_gene(cons1, cons2)) + new_cons = unflatten_connections(len(keys1), new_cons) + + return new_nodes, new_cons + + +def align_array(seq1: NDArray, seq2: NDArray, ar2: NDArray, gene_type: str) -> NDArray: + """ + make ar2 align with ar1. + :param seq1: + :param seq2: + :param ar2: + :param gene_type: + :return: + align means to intersect part of ar2 will be at the same position as ar1, + non-intersect part of ar2 will be set to Nan + """ + seq1, seq2 = seq1[:, np.newaxis], seq2[np.newaxis, :] + mask = (seq1 == seq2) & (~np.isnan(seq1)) + + if gene_type == 'connection': + mask = np.all(mask, axis=2) + + intersect_mask = mask.any(axis=1) + idx = np.arange(0, len(seq1)) + idx_fixed = np.dot(mask, idx) + + refactor_ar2 = np.where(intersect_mask[:, np.newaxis], ar2[idx_fixed], np.nan) + + return refactor_ar2 + + +def crossover_gene(g1: NDArray, g2: NDArray) -> NDArray: + """ + crossover two genes + :param g1: + :param g2: + :return: + only gene with the same key will be crossover, thus don't need to consider change key + """ + r = np.random.rand() + return np.where(r > 0.5, g1, g2) diff --git a/algorithms/neat/genome/numpy/distance.py b/algorithms/neat/genome/numpy/distance.py new file mode 100644 index 0000000..b5d313a --- /dev/null +++ b/algorithms/neat/genome/numpy/distance.py @@ -0,0 +1,94 @@ +from functools import partial + +import numpy as np +from numpy.typing import NDArray + +from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis + + +def distance(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) -> NDArray: + """ + Calculate the distance between two genomes. + nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg] + connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] + """ + + node_distance = gene_distance(nodes1, nodes2, 'node') + + # refactor connections + keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + cons1 = flatten_connections(keys1, connections1) + cons2 = flatten_connections(keys2, connections2) + + connection_distance = gene_distance(cons1, cons2, 'connection') + return node_distance + connection_distance + + +def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): + if gene_type == 'node': + keys1, keys2 = ar1[:, :1], ar2[:, :1] + else: # connection + keys1, keys2 = ar1[:, :2], ar2[:, :2] + + n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) + nodes = np.concatenate((ar1, ar2), axis=0) + sorted_nodes = nodes[n_sorted_indices] + + if gene_type == 'node': + node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 1:]), axis=1) + else: + node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 2:]), axis=1) + + n_intersect_mask = n_intersect_mask & node_exist_mask + n_union_mask = n_union_mask & node_exist_mask + + fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] + + non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask) + if gene_type == 'node': + node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) + else: # connection + node_distance = batch_homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes) + + node_distance = np.where(np.isnan(node_distance), 0, node_distance) + homologous_distance = np.sum(node_distance * n_intersect_mask[:-1]) + + gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1)) + gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1)) + max_cnt = np.maximum(gene_cnt1, gene_cnt2) + + val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe + + return np.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene + + +def batch_homologous_node_distance(b_n1, b_n2): + res = [] + for n1, n2 in zip(b_n1, b_n2): + d = homologous_node_distance(n1, n2) + res.append(d) + return np.stack(res, axis=0) + + +def batch_homologous_connection_distance(b_c1, b_c2): + res = [] + for c1, c2 in zip(b_c1, b_c2): + d = homologous_connection_distance(c1, c2) + res.append(d) + return np.stack(res, axis=0) + + +def homologous_node_distance(n1, n2): + d = 0 + d += np.abs(n1[1] - n2[1]) # bias + d += np.abs(n1[2] - n2[2]) # response + d += n1[3] != n2[3] # activation + d += n1[4] != n2[4] + return d + + +def homologous_connection_distance(c1, c2): + d = 0 + d += np.abs(c1[2] - c2[2]) # weight + d += c1[3] != c2[3] # enable + return d diff --git a/algorithms/neat/genome/numpy/forward.py b/algorithms/neat/genome/numpy/forward.py new file mode 100644 index 0000000..b16b21d --- /dev/null +++ b/algorithms/neat/genome/numpy/forward.py @@ -0,0 +1,151 @@ +from functools import partial + +import numpy as np +from numpy.typing import NDArray + +from .aggregations import agg +from .activations import act +from .graph import topological_sort, batch_topological_sort +from .utils import I_INT + + +def create_forward_function(nodes: NDArray, connections: NDArray, + N: int, input_idx: NDArray, output_idx: NDArray, batch: bool): + """ + create forward function for different situations + + :param nodes: shape (N, 5) or (pop_size, N, 5) + :param connections: shape (2, N, N) or (pop_size, 2, N, N) + :param N: + :param input_idx: + :param output_idx: + :param batch: using batch or not + :param debug: debug mode + :return: + """ + + if nodes.ndim == 2: # single genome + cal_seqs = topological_sort(nodes, connections) + if not batch: + return lambda inputs: forward_single(inputs, N, input_idx, output_idx, + cal_seqs, nodes, connections) + else: + return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx, + cal_seqs, nodes, connections) + elif nodes.ndim == 3: # pop genome + pop_cal_seqs = batch_topological_sort(nodes, connections) + if not batch: + return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx, + pop_cal_seqs, nodes, connections) + else: + return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx, + pop_cal_seqs, nodes, connections) + else: + raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}") + + +def forward_single(inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray, + cal_seqs: NDArray, nodes: NDArray, connections: NDArray) -> NDArray: + """ + jax forward for single input shaped (input_num, ) + nodes, connections are single genome + + :argument inputs: (input_num, ) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument cal_seqs: (N, ) + :argument nodes: (N, 5) + :argument connections: (2, N, N) + + :return (output_num, ) + """ + ini_vals = np.full((N,), np.nan) + ini_vals[input_idx] = inputs + + for i in cal_seqs: + if i in input_idx: + continue + if i == I_INT: + break + ins = ini_vals * connections[0, :, i] + z = agg(nodes[i, 4], ins) + z = z * nodes[i, 2] + nodes[i, 1] + z = act(nodes[i, 3], z) + + # for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals + ini_vals[i] = z + + + return ini_vals[output_idx] + + +def forward_batch(batch_inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray, + cal_seqs: NDArray, nodes: NDArray, connections: NDArray) -> NDArray: + """ + jax forward for batch_inputs shaped (batch_size, input_num) + nodes, connections are single genome + + :argument batch_inputs: (batch_size, input_num) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument cal_seqs: (N, ) + :argument nodes: (N, 5) + :argument connections: (2, N, N) + + :return (batch_size, output_num) + """ + res = [] + for inputs in batch_inputs: + out = forward_single(inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) + res.append(out) + return np.stack(res, axis=0) + + +def pop_forward_single(inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray, + pop_cal_seqs: NDArray, pop_nodes: NDArray, pop_connections: NDArray) -> NDArray: + """ + jax forward for single input shaped (input_num, ) + pop_nodes, pop_connections are population of genomes + + :argument inputs: (input_num, ) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument pop_cal_seqs: (pop_size, N) + :argument pop_nodes: (pop_size, N, 5) + :argument pop_connections: (pop_size, 2, N, N) + + :return (pop_size, output_num) + """ + res = [] + for cal_seqs, nodes, connections in zip(pop_cal_seqs, pop_nodes, pop_connections): + out = forward_single(inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) + res.append(out) + + return np.stack(res, axis=0) + + +def pop_forward_batch(batch_inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray, + pop_cal_seqs: NDArray, pop_nodes: NDArray, pop_connections: NDArray) -> NDArray: + """ + jax forward for batch input shaped (batch, input_num) + pop_nodes, pop_connections are population of genomes + + :argument batch_inputs: (batch_size, input_num) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument pop_cal_seqs: (pop_size, N) + :argument pop_nodes: (pop_size, N, 5) + :argument pop_connections: (pop_size, 2, N, N) + + :return (pop_size, batch_size, output_num) + """ + res = [] + for cal_seqs, nodes, connections in zip(pop_cal_seqs, pop_nodes, pop_connections): + out = forward_batch(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) + res.append(out) + + return np.stack(res, axis=0) diff --git a/algorithms/neat/genome/numpy/genome.py b/algorithms/neat/genome/numpy/genome.py new file mode 100644 index 0000000..a5485e1 --- /dev/null +++ b/algorithms/neat/genome/numpy/genome.py @@ -0,0 +1,270 @@ +""" +Vectorization of genome representation. + +Utilizes Tuple[nodes: NDArray, connections: NDArray] to encode the genome, where: + +1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes +too large to be represented by the current value of N. +2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function +(act), and aggregation function (agg). +3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled +status. +Empty nodes or connections are represented using np.nan. + +""" +from typing import Tuple, Dict +from functools import partial + +import numpy as np +from numpy.typing import NDArray + +from algorithms.neat.genome.utils import fetch_first + +EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) + + +def create_initialize_function(config): + pop_size = config.neat.population.pop_size + N = config.basic.init_maximum_nodes + num_inputs = config.basic.num_inputs + num_outputs = config.basic.num_outputs + default_bias = config.neat.gene.bias.init_mean + default_response = config.neat.gene.response.init_mean + # default_act = config.neat.gene.activation.default + # default_agg = config.neat.gene.aggregation.default + default_act = 0 + default_agg = 0 + default_weight = config.neat.gene.weight.init_mean + return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response, + default_act, default_agg, default_weight) + + +def initialize_genomes(pop_size: int, + N: int, + num_inputs: int, num_outputs: int, + default_bias: float = 0.0, + default_response: float = 1.0, + default_act: int = 0, + default_agg: int = 0, + default_weight: float = 1.0) \ + -> Tuple[NDArray, NDArray, NDArray, NDArray]: + """ + Initialize genomes with default values. + + Args: + pop_size (int): Number of genomes to initialize. + N (int): Maximum number of nodes in the network. + num_inputs (int): Number of input nodes. + num_outputs (int): Number of output nodes. + default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0. + default_response (float, optional): Default response value for output nodes. Defaults to 1.0. + default_act (int, optional): Default activation function index for output nodes. Defaults to 1. + default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0. + default_weight (float, optional): Default weight value for connections. Defaults to 0.0. + + Raises: + AssertionError: If the sum of num_inputs, num_outputs, and 1 is greater than N. + + Returns: + Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays. + """ + # Reserve one row for potential mutation adding an extra node + assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \ + f"{num_inputs} and output_size: {num_outputs}!" + + pop_nodes = np.full((pop_size, N, 5), np.nan) + pop_connections = np.full((pop_size, 2, N, N), np.nan) + input_idx = np.arange(num_inputs) + output_idx = np.arange(num_inputs, num_inputs + num_outputs) + + pop_nodes[:, input_idx, 0] = input_idx + pop_nodes[:, output_idx, 0] = output_idx + + pop_nodes[:, output_idx, 1] = default_bias + pop_nodes[:, output_idx, 2] = default_response + pop_nodes[:, output_idx, 3] = default_act + pop_nodes[:, output_idx, 4] = default_agg + + for i in input_idx: + for j in output_idx: + pop_connections[:, 0, i, j] = default_weight + pop_connections[:, 1, i, j] = 1 + + return pop_nodes, pop_connections, input_idx, output_idx + + +def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: + """ + Expand the genome to accommodate more nodes. + :param pop_nodes: (pop_size, N, 5) + :param pop_connections: (pop_size, 2, N, N) + :param new_N: + :return: + """ + pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1] + + new_pop_nodes = np.full((pop_size, new_N, 5), np.nan) + new_pop_nodes[:, :old_N, :] = pop_nodes + + new_pop_connections = np.full((pop_size, 2, new_N, new_N), np.nan) + new_pop_connections[:, :, :old_N, :old_N] = pop_connections + return new_pop_nodes, new_pop_connections + + +def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: + """ + Expand a single genome to accommodate more nodes. + :param nodes: (N, 5) + :param connections: (2, N, N) + :param new_N: + :return: + """ + old_N = nodes.shape[0] + new_nodes = np.full((new_N, 5), np.nan) + new_nodes[:old_N, :] = nodes + + new_connections = np.full((2, new_N, new_N), np.nan) + new_connections[:, :old_N, :old_N] = connections + + return new_nodes, new_connections + + +def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \ + Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]: + """ + Convert a genome from array to dict. + :param nodes: (N, 5) + :param connections: (2, N, N) + :param output_keys: + :param input_keys: + :return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)] + """ + # update nodes_dict + try: + nodes_dict = {} + idx2key = {} + for i, node in enumerate(nodes): + if np.isnan(node[0]): + continue + key = int(node[0]) + assert key not in nodes_dict, f"Duplicate node key: {key}!" + + bias = node[1] if not np.isnan(node[1]) else None + response = node[2] if not np.isnan(node[2]) else None + act = node[3] if not np.isnan(node[3]) else None + agg = node[4] if not np.isnan(node[4]) else None + nodes_dict[key] = (bias, response, act, agg) + idx2key[i] = key + + # check nodes_dict + for i in input_keys: + assert i in nodes_dict, f"Input node {i} not found in nodes_dict!" + bias, response, act, agg = nodes_dict[i] + assert bias is None and response is None and act is None and agg is None, \ + f"Input node {i} must has None bias, response, act, or agg!" + + for o in output_keys: + assert o in nodes_dict, f"Output node {o} not found in nodes_dict!" + + for k, v in nodes_dict.items(): + if k not in input_keys: + bias, response, act, agg = v + assert bias is not None and response is not None and act is not None and agg is not None, \ + f"Normal node {k} must has non-None bias, response, act, or agg!" + + # update connections + connections_dict = {} + for i in range(connections.shape[1]): + for j in range(connections.shape[2]): + if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]): + continue + assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!" + assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!" + key = (idx2key[i], idx2key[j]) + + weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None + enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None + + assert weight is not None, f"Connection {key} must has non-None weight!" + assert enabled is not None, f"Connection {key} must has non-None enabled!" + connections_dict[key] = (weight, enabled) + + return nodes_dict, connections_dict + except AssertionError: + print(nodes) + print(connections) + raise AssertionError + + +def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys): + res = [] + for nodes, connections in zip(pop_nodes, pop_connections): + res.append(analysis(nodes, connections, input_keys, output_keys)) + return res + + +def add_node(new_node_key: int, nodes: NDArray, connections: NDArray, + bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]: + """ + add a new node to the genome. + """ + exist_keys = nodes[:, 0] + idx = fetch_first(np.isnan(exist_keys)) + nodes[idx] = np.array([new_node_key, bias, response, act, agg]) + return nodes, connections + + +def delete_node(node_key: int, nodes: NDArray, connections: NDArray) -> Tuple[NDArray, NDArray]: + """ + delete a node from the genome. only delete the node, regardless of connections. + """ + node_keys = nodes[:, 0] + idx = fetch_first(node_keys == node_key) + return delete_node_by_idx(idx, nodes, connections) + + +def delete_node_by_idx(idx: int, nodes: NDArray, connections: NDArray) -> Tuple[NDArray, NDArray]: + """ + delete a node from the genome. only delete the node, regardless of connections. + """ + nodes[idx] = EMPTY_NODE + return nodes, connections + + +def add_connection(from_node: int, to_node: int, nodes: NDArray, connections: NDArray, + weight: float = 0.0, enabled: bool = True) -> Tuple[NDArray, NDArray]: + """ + add a new connection to the genome. + """ + node_keys = nodes[:, 0] + from_idx = fetch_first(node_keys == from_node) + to_idx = fetch_first(node_keys == to_node) + return add_connection_by_idx(from_idx, to_idx, nodes, connections, weight, enabled) + + +def add_connection_by_idx(from_idx: int, to_idx: int, nodes: NDArray, connections: NDArray, + weight: float = 0.0, enabled: bool = True) -> Tuple[NDArray, NDArray]: + """ + add a new connection to the genome. + """ + connections[:, from_idx, to_idx] = np.array([weight, enabled]) + return nodes, connections + + +def delete_connection(from_node: int, to_node: int, nodes: NDArray, connections: NDArray) -> Tuple[NDArray, NDArray]: + """ + delete a connection from the genome. + """ + node_keys = nodes[:, 0] + from_idx = fetch_first(node_keys == from_node) + to_idx = fetch_first(node_keys == to_node) + return delete_connection_by_idx(from_idx, to_idx, nodes, connections) + + +def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: NDArray, connections: NDArray) -> Tuple[ + NDArray, NDArray]: + """ + delete a connection from the genome. + """ + connections[:, from_idx, to_idx] = np.nan + return nodes, connections diff --git a/algorithms/neat/genome/numpy/graph.py b/algorithms/neat/genome/numpy/graph.py new file mode 100644 index 0000000..dff90b8 --- /dev/null +++ b/algorithms/neat/genome/numpy/graph.py @@ -0,0 +1,163 @@ +""" +Some graph algorithms implemented in jax. +Only used in feed-forward networks. +""" + +import numpy as np +from numpy.typing import NDArray + +# from .utils import fetch_first, I_INT +from algorithms.neat.genome.utils import fetch_first, I_INT + + +def topological_sort(nodes: NDArray, connections: NDArray) -> NDArray: + """ + a jit-able version of topological_sort! that's crazy! + :param nodes: nodes array + :param connections: connections array + :return: topological sorted sequence + + Example: + nodes = np.array([ + [0], + [1], + [2], + [3] + ]) + connections = np.array([ + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ], + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ] + ]) + + topological_sort(nodes, connections) -> [0, 1, 2, 3] + """ + connections_enable = connections[1, :, :] == 1 + in_degree = np.where(np.isnan(nodes[:, 0]), np.nan, np.sum(connections_enable, axis=0)) + res = np.full(in_degree.shape, I_INT) + idx = 0 + + for _ in range(in_degree.shape[0]): + i = fetch_first(in_degree == 0.) + if i == I_INT: + break + res[idx] = i + idx += 1 + in_degree[i] = -1 + children = connections_enable[i, :] + in_degree = np.where(children, in_degree - 1, in_degree) + + return res + + +def batch_topological_sort(pop_nodes: NDArray, pop_connections: NDArray) -> NDArray: + """ + batch version of topological_sort + :param pop_nodes: + :param pop_connections: + :return: + """ + res = [] + for nodes, connections in zip(pop_nodes, pop_connections): + seq = topological_sort(nodes, connections) + res.append(seq) + return np.stack(res, axis=0) + + +def check_cycles(nodes: NDArray, connections: NDArray, from_idx: NDArray, to_idx: NDArray) -> NDArray: + """ + Check whether a new connection (from_idx -> to_idx) will cause a cycle. + + :param nodes: JAX array + The array of nodes. + :param connections: JAX array + The array of connections. + :param from_idx: int + The index of the starting node. + :param to_idx: int + The index of the ending node. + :return: JAX array + An array indicating if there is a cycle caused by the new connection. + + Example: + nodes = np.array([ + [0], + [1], + [2], + [3] + ]) + connections = np.array([ + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ], + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ] + ]) + + check_cycles(nodes, connections, 3, 2) -> True + check_cycles(nodes, connections, 2, 3) -> False + check_cycles(nodes, connections, 0, 3) -> False + check_cycles(nodes, connections, 1, 0) -> False + """ + connections_enable = ~np.isnan(connections[0, :, :]) + + connections_enable[from_idx, to_idx] = True + nodes_visited = np.full(nodes.shape[0], False) + nodes_visited[to_idx] = True + + for _ in range(nodes_visited.shape[0]): + new_visited = np.dot(nodes_visited, connections_enable) + nodes_visited = np.logical_or(nodes_visited, new_visited) + + return nodes_visited[from_idx] + + +if __name__ == '__main__': + nodes = np.array([ + [0], + [1], + [2], + [3], + [np.nan] + ]) + connections = np.array([ + [ + [np.nan, np.nan, 1, np.nan, np.nan], + [np.nan, np.nan, 1, 1, np.nan], + [np.nan, np.nan, np.nan, 1, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan] + ], + [ + [np.nan, np.nan, 1, np.nan, np.nan], + [np.nan, np.nan, 1, 1, np.nan], + [np.nan, np.nan, np.nan, 1, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan] + ] + ] + ) + + print(topological_sort(nodes, connections)) + print(topological_sort(nodes, connections)) + + print(check_cycles(nodes, connections, 3, 2)) + print(check_cycles(nodes, connections, 2, 3)) + print(check_cycles(nodes, connections, 0, 3)) + print(check_cycles(nodes, connections, 1, 0)) diff --git a/algorithms/neat/genome/numpy/mutate.py b/algorithms/neat/genome/numpy/mutate.py new file mode 100644 index 0000000..acd3739 --- /dev/null +++ b/algorithms/neat/genome/numpy/mutate.py @@ -0,0 +1,531 @@ +from typing import Tuple +from functools import partial + +import numpy as np +from numpy.typing import NDArray +from numpy.random import rand + +from .utils import fetch_random, fetch_first, I_INT +from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx +from .graph import check_cycles + + +def create_mutate_function(config, input_keys, output_keys, batch: bool): + """ + create mutate function for different situations + :param output_keys: + :param input_keys: + :param config: + :param batch: mutate for population or not + :return: + """ + bias = config.neat.gene.bias + bias_default = bias.init_mean + bias_mean = bias.init_mean + bias_std = bias.init_stdev + bias_mutate_strength = bias.mutate_power + bias_mutate_rate = bias.mutate_rate + bias_replace_rate = bias.replace_rate + + response = config.neat.gene.response + response_default = response.init_mean + response_mean = response.init_mean + response_std = response.init_stdev + response_mutate_strength = response.mutate_power + response_mutate_rate = response.mutate_rate + response_replace_rate = response.replace_rate + + weight = config.neat.gene.weight + weight_mean = weight.init_mean + weight_std = weight.init_stdev + weight_mutate_strength = weight.mutate_power + weight_mutate_rate = weight.mutate_rate + weight_replace_rate = weight.replace_rate + + activation = config.neat.gene.activation + # act_default = activation.default + act_default = 0 + act_range = len(activation.options) + act_replace_rate = activation.mutate_rate + + aggregation = config.neat.gene.aggregation + # agg_default = aggregation.default + agg_default = 0 + agg_range = len(aggregation.options) + agg_replace_rate = aggregation.mutate_rate + + enabled = config.neat.gene.enabled + enabled_reverse_rate = enabled.mutate_rate + + genome = config.neat.genome + add_node_rate = genome.node_add_prob + delete_node_rate = genome.node_delete_prob + add_connection_rate = genome.conn_add_prob + delete_connection_rate = genome.conn_delete_prob + single_structure_mutate = genome.single_structural_mutation + + mutate_func = lambda nodes, connections, new_node_key: \ + mutate(nodes, connections, new_node_key, input_keys, output_keys, + bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, + bias_replace_rate, response_default, response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate, + weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, + weight_replace_rate, act_default, act_range, act_replace_rate, + agg_default, agg_range, agg_replace_rate, enabled_reverse_rate, + add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate, + single_structure_mutate) + + if not batch: + return mutate_func + else: + def batch_mutate_func(pop_nodes, pop_connections, new_node_keys): + res_nodes, res_connections = [], [] + for nodes, connections, new_node_key in zip(pop_nodes, pop_connections, new_node_keys): + nodes, connections = mutate_func(nodes, connections, new_node_key) + res_nodes.append(nodes) + res_connections.append(connections) + return np.stack(res_nodes, axis=0), np.stack(res_connections, axis=0) + + return batch_mutate_func + + +def mutate(nodes: NDArray, + connections: NDArray, + new_node_key: int, + input_keys: NDArray, + output_keys: NDArray, + bias_default: float = 0, + bias_mean: float = 0, + bias_std: float = 1, + bias_mutate_strength: float = 0.5, + bias_mutate_rate: float = 0.7, + bias_replace_rate: float = 0.1, + response_default: float = 1, + response_mean: float = 1., + response_std: float = 0., + response_mutate_strength: float = 0., + response_mutate_rate: float = 0., + response_replace_rate: float = 0., + weight_mean: float = 0., + weight_std: float = 1., + weight_mutate_strength: float = 0.5, + weight_mutate_rate: float = 0.7, + weight_replace_rate: float = 0.1, + act_default: int = 0, + act_range: int = 5, + act_replace_rate: float = 0.1, + agg_default: int = 0, + agg_range: int = 5, + agg_replace_rate: float = 0.1, + enabled_reverse_rate: float = 0.1, + add_node_rate: float = 0.2, + delete_node_rate: float = 0.2, + add_connection_rate: float = 0.4, + delete_connection_rate: float = 0.4, + single_structure_mutate: bool = True): + """ + :param output_keys: + :param input_keys: + :param agg_default: + :param act_default: + :param response_default: + :param bias_default: + :param nodes: (N, 5) + :param connections: (2, N, N) + :param new_node_key: + :param bias_mean: + :param bias_std: + :param bias_mutate_strength: + :param bias_mutate_rate: + :param bias_replace_rate: + :param response_mean: + :param response_std: + :param response_mutate_strength: + :param response_mutate_rate: + :param response_replace_rate: + :param weight_mean: + :param weight_std: + :param weight_mutate_strength: + :param weight_mutate_rate: + :param weight_replace_rate: + :param act_range: + :param act_replace_rate: + :param agg_range: + :param agg_replace_rate: + :param enabled_reverse_rate: + :param add_node_rate: + :param delete_node_rate: + :param add_connection_rate: + :param delete_connection_rate: + :param single_structure_mutate: a genome is structurally mutate at most once + :return: + """ + + # mutate_structure + def nothing(n, c): + return n, c + + def m_add_node(n, c): + return mutate_add_node(new_node_key, n, c, bias_default, response_default, act_default, agg_default) + + def m_delete_node(n, c): + return mutate_delete_node(n, c, input_keys, output_keys) + + def m_add_connection(n, c): + return mutate_add_connection(n, c, input_keys, output_keys) + + def m_delete_connection(n, c): + return mutate_delete_connection(n, c) + + if single_structure_mutate: + d = np.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate) + + # shorten variable names for beauty + anr, dnr = add_node_rate / d, delete_node_rate / d + acr, dcr = add_connection_rate / d, delete_connection_rate / d + + r = rand() + if r <= anr: + nodes, connections = m_add_node(nodes, connections) + elif r <= anr + dnr: + nodes, connections = m_delete_node(nodes, connections) + elif r <= anr + dnr + acr: + nodes, connections = m_add_connection(nodes, connections) + elif r <= anr + dnr + acr + dcr: + nodes, connections = m_delete_connection(nodes, connections) + else: + pass # do nothing + + else: + # mutate add node + if rand() < add_node_rate: + nodes, connections = m_add_node(nodes, connections) + + # mutate delete node + if rand() < delete_node_rate: + nodes, connections = m_delete_node(nodes, connections) + + # mutate add connection + if rand() < add_connection_rate: + nodes, connections = m_add_connection(nodes, connections) + + # mutate delete connection + if rand() < delete_connection_rate: + nodes, connections = m_delete_connection(nodes, connections) + + nodes, connections = mutate_values(nodes, connections, bias_mean, bias_std, bias_mutate_strength, + bias_mutate_rate, bias_replace_rate, response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate, + weight_mean, weight_std, weight_mutate_strength, + weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range, + agg_replace_rate, enabled_reverse_rate) + + return nodes, connections + + +def mutate_values(nodes: NDArray, + connections: NDArray, + bias_mean: float = 0, + bias_std: float = 1, + bias_mutate_strength: float = 0.5, + bias_mutate_rate: float = 0.7, + bias_replace_rate: float = 0.1, + response_mean: float = 1., + response_std: float = 0., + response_mutate_strength: float = 0., + response_mutate_rate: float = 0., + response_replace_rate: float = 0., + weight_mean: float = 0., + weight_std: float = 1., + weight_mutate_strength: float = 0.5, + weight_mutate_rate: float = 0.7, + weight_replace_rate: float = 0.1, + act_range: int = 5, + act_replace_rate: float = 0.1, + agg_range: int = 5, + agg_replace_rate: float = 0.1, + enabled_reverse_rate: float = 0.1) -> Tuple[NDArray, NDArray]: + """ + Mutate values of nodes and connections. + + Args: + nodes: A 2D array representing nodes. + connections: A 3D array representing connections. + bias_mean: Mean of the bias values. + bias_std: Standard deviation of the bias values. + bias_mutate_strength: Strength of the bias mutation. + bias_mutate_rate: Rate of the bias mutation. + bias_replace_rate: Rate of the bias replacement. + response_mean: Mean of the response values. + response_std: Standard deviation of the response values. + response_mutate_strength: Strength of the response mutation. + response_mutate_rate: Rate of the response mutation. + response_replace_rate: Rate of the response replacement. + weight_mean: Mean of the weight values. + weight_std: Standard deviation of the weight values. + weight_mutate_strength: Strength of the weight mutation. + weight_mutate_rate: Rate of the weight mutation. + weight_replace_rate: Rate of the weight replacement. + act_range: Range of the activation function values. + act_replace_rate: Rate of the activation function replacement. + agg_range: Range of the aggregation function values. + agg_replace_rate: Rate of the aggregation function replacement. + enabled_reverse_rate: Rate of reversing enabled state of connections. + + Returns: + A tuple containing mutated nodes and connections. + """ + + bias_new = mutate_float_values(nodes[:, 1], bias_mean, bias_std, + bias_mutate_strength, bias_mutate_rate, bias_replace_rate) + response_new = mutate_float_values(nodes[:, 2], response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate) + weight_new = mutate_float_values(connections[0, :, :], weight_mean, weight_std, + weight_mutate_strength, weight_mutate_rate, weight_replace_rate) + act_new = mutate_int_values(nodes[:, 3], act_range, act_replace_rate) + agg_new = mutate_int_values(nodes[:, 4], agg_range, agg_replace_rate) + + # refactor enabled + r = np.random.rand(*connections[1, :, :].shape) + enabled_new = connections[1, :, :] == 1 + enabled_new = np.where(r < enabled_reverse_rate, ~enabled_new, enabled_new) + enabled_new = np.where(~np.isnan(connections[0, :, :]), enabled_new, np.nan) + + nodes[:, 1] = bias_new + nodes[:, 2] = response_new + nodes[:, 3] = act_new + nodes[:, 4] = agg_new + connections[0, :, :] = weight_new + connections[1, :, :] = enabled_new + + return nodes, connections + + +def mutate_float_values(old_vals: NDArray, mean: float, std: float, + mutate_strength: float, mutate_rate: float, replace_rate: float) -> NDArray: + """ + Mutate float values of a given array. + + Args: + old_vals: A 1D array of float values to be mutated. + mean: Mean of the values. + std: Standard deviation of the values. + mutate_strength: Strength of the mutation. + mutate_rate: Rate of the mutation. + replace_rate: Rate of the replacement. + + Returns: + A mutated 1D array of float values. + """ + noise = np.random.normal(size=old_vals.shape) * mutate_strength + replace = np.random.normal(size=old_vals.shape) * std + mean + r = rand(*old_vals.shape) + new_vals = old_vals + new_vals = np.where(r < mutate_rate, new_vals + noise, new_vals) + new_vals = np.where( + np.logical_and(mutate_rate < r, r < mutate_rate + replace_rate), + replace, + new_vals + ) + new_vals = np.where(~np.isnan(old_vals), new_vals, np.nan) + return new_vals + + +def mutate_int_values(old_vals: NDArray, range: int, replace_rate: float) -> NDArray: + """ + Mutate integer values (act, agg) of a given array. + + Args: + old_vals: A 1D array of integer values to be mutated. + range: Range of the integer values. + replace_rate: Rate of the replacement. + + Returns: + A mutated 1D array of integer values. + """ + replace_val = np.random.randint(low=0, high=range, size=old_vals.shape) + r = np.random.rand(*old_vals.shape) + new_vals = old_vals + new_vals = np.where(r < replace_rate, replace_val, new_vals) + new_vals = np.where(~np.isnan(old_vals), new_vals, np.nan) + return new_vals + + +def mutate_add_node(new_node_key: int, nodes: NDArray, connections: NDArray, + default_bias: float = 0, default_response: float = 1, + default_act: int = 0, default_agg: int = 0) -> Tuple[NDArray, NDArray]: + """ + Randomly add a new node from splitting a connection. + :param new_node_key: + :param nodes: + :param connections: + :param default_bias: + :param default_response: + :param default_act: + :param default_agg: + :return: + """ + # randomly choose a connection + from_key, to_key, from_idx, to_idx = choice_connection_key(nodes, connections) + + def nothing(): + return nodes, connections + + def successful_add_node(): + # disable the connection + new_nodes, new_connections = nodes, connections + new_connections[1, from_idx, to_idx] = False + + # add a new node + new_nodes, new_connections = \ + add_node(new_node_key, new_nodes, new_connections, + bias=default_bias, response=default_response, act=default_act, agg=default_agg) + new_idx = fetch_first(new_nodes[:, 0] == new_node_key) + + # add two new connections + weight = new_connections[0, from_idx, to_idx] + new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx, + new_nodes, new_connections, weight=0, enabled=True) + new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx, + new_nodes, new_connections, weight=weight, enabled=True) + return new_nodes, new_connections + + # if from_idx == I_INT, that means no connection exist, do nothing + if from_idx == I_INT: + nodes, connections = nothing() + else: + nodes, connections = successful_add_node() + + return nodes, connections + + +def mutate_delete_node(nodes: NDArray, connections: NDArray, + input_keys: NDArray, output_keys: NDArray) -> Tuple[NDArray, NDArray]: + """ + Randomly delete a node. Input and output nodes are not allowed to be deleted. + :param nodes: + :param connections: + :param input_keys: + :param output_keys: + :return: + """ + # randomly choose a node + node_key, node_idx = choice_node_key(nodes, input_keys, output_keys, + allow_input_keys=False, allow_output_keys=False) + + if np.isnan(node_key): + return nodes, connections + + # delete the node + aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections) + + # delete connections + aux_connections[:, node_idx, :] = np.nan + aux_connections[:, :, node_idx] = np.nan + + return aux_nodes, aux_connections + + +def mutate_add_connection(nodes: NDArray, connections: NDArray, + input_keys: NDArray, output_keys: NDArray) -> Tuple[NDArray, NDArray]: + """ + Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, + cycles are not allowed. + :param nodes: + :param connections: + :param input_keys: + :param output_keys: + :return: + """ + # randomly choose two nodes + from_key, from_idx = choice_node_key(nodes, input_keys, output_keys, + allow_input_keys=True, allow_output_keys=True) + to_key, to_idx = choice_node_key(nodes, input_keys, output_keys, + allow_input_keys=False, allow_output_keys=True) + + is_already_exist = ~np.isnan(connections[0, from_idx, to_idx]) + + if is_already_exist: + connections[1, from_idx, to_idx] = True + return nodes, connections + elif check_cycles(nodes, connections, from_idx, to_idx): + return nodes, connections + else: + new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections) + return new_nodes, new_connections + + +def mutate_delete_connection(nodes: NDArray, connections: NDArray): + """ + Randomly delete a connection. + :param nodes: + :param connections: + :return: + """ + from_key, to_key, from_idx, to_idx = choice_connection_key(nodes, connections) + + def nothing(): + return nodes, connections + + def successfully_delete_connection(): + return delete_connection_by_idx(from_idx, to_idx, nodes, connections) + + if from_idx == I_INT: + nodes, connections = nothing() + else: + nodes, connections = successfully_delete_connection() + + return nodes, connections + + +def choice_node_key(nodes: NDArray, + input_keys: NDArray, output_keys: NDArray, + allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[NDArray, NDArray]: + """ + Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. + :param nodes: + :param input_keys: + :param output_keys: + :param allow_input_keys: + :param allow_output_keys: + :return: return its key and position(idx) + """ + + node_keys = nodes[:, 0] + mask = ~np.isnan(node_keys) + + if not allow_input_keys: + mask = np.logical_and(mask, ~np.isin(node_keys, input_keys)) + + if not allow_output_keys: + mask = np.logical_and(mask, ~np.isin(node_keys, output_keys)) + + idx = fetch_random(mask) + + if idx == I_INT: + return np.nan, idx + else: + return node_keys[idx], idx + + +def choice_connection_key(nodes: NDArray, connection: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray]: + """ + Randomly choose a connection key from the given connections. + :param nodes: + :param connection: + :return: from_key, to_key, from_idx, to_idx + """ + has_connections_row = np.any(~np.isnan(connection[0, :, :]), axis=1) + from_idx = fetch_random(has_connections_row) + + if from_idx == I_INT: + return np.nan, np.nan, from_idx, I_INT + + col = connection[0, from_idx, :] + to_idx = fetch_random(~np.isnan(col)) + from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0] + + from_key = np.where(from_idx != I_INT, from_key, np.nan) + to_key = np.where(to_idx != I_INT, to_key, np.nan) + + return from_key, to_key, from_idx, to_idx diff --git a/algorithms/neat/genome/numpy/utils.py b/algorithms/neat/genome/numpy/utils.py new file mode 100644 index 0000000..dec1983 --- /dev/null +++ b/algorithms/neat/genome/numpy/utils.py @@ -0,0 +1,128 @@ +from functools import partial +from typing import Tuple + +import numpy as np +from numpy.typing import NDArray + +I_INT = np.iinfo(np.int32).max # infinite int + + +def flatten_connections(keys, connections): + """ + flatten the (2, N, N) connections to (N * N, 4) + :param keys: + :param connections: + :return: + the first two columns are the index of the node + the 3rd column is the weight, and the 4th column is the enabled status + """ + indices_x, indices_y = np.meshgrid(keys, keys, indexing='ij') + indices = np.stack((indices_x, indices_y), axis=-1).reshape(-1, 2) + + # make (2, N, N) to (N, N, 2) + con = np.transpose(connections, (1, 2, 0)) + # make (N, N, 2) to (N * N, 2) + con = np.reshape(con, (-1, 2)) + + con = np.concatenate((indices, con), axis=1) + return con + + +def unflatten_connections(N, cons): + """ + restore the (N * N, 4) connections to (2, N, N) + :param N: + :param cons: + :return: + """ + cons = cons[:, 2:] # remove the indices + unflatten_cons = np.moveaxis(cons.reshape(N, N, 2), -1, 0) + return unflatten_cons + + +def set_operation_analysis(ar1: NDArray, ar2: NDArray) -> Tuple[NDArray, NDArray, NDArray]: + """ + Analyze the intersection and union of two arrays by returning their sorted concatenation indices, + intersection mask, and union mask. + + :param ar1: JAX array of shape (N, M) + First input array. Should have the same shape as ar2. + :param ar2: JAX array of shape (N, M) + Second input array. Should have the same shape as ar1. + :return: tuple of 3 arrays + - sorted_indices: Indices that would sort the concatenation of ar1 and ar2. + - intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2 + in the sorted concatenation. + - union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2 + in the sorted concatenation. + + Examples: + a = np.array([[1, 2], [3, 4], [5, 6]]) + b = np.array([[1, 2], [7, 8], [9, 10]]) + + sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b) + + sorted_indices -> array([0, 1, 2, 3, 4, 5]) + intersect_mask -> array([True, False, False, False, False, False]) + union_mask -> array([False, True, True, True, True, True]) + """ + ar = np.concatenate((ar1, ar2), axis=0) + sorted_indices = np.lexsort(ar.T[::-1]) + aux = ar[sorted_indices] + aux = np.concatenate((aux, np.full((1, ar1.shape[1]), np.nan)), axis=0) + nan_mask = np.any(np.isnan(aux), axis=1) + + fr, sr = aux[:-1], aux[1:] # first row, second row + intersect_mask = np.all(fr == sr, axis=1) & ~nan_mask[:-1] + union_mask = np.any(fr != sr, axis=1) & ~nan_mask[:-1] + return sorted_indices, intersect_mask, union_mask + + +def fetch_first(mask, default=I_INT) -> NDArray: + """ + fetch the first True index + :param mask: array of bool + :param default: the default value if no element satisfying the condition + :return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT + example: + >>> a = np.array([1, 2, 3, 4, 5]) + >>> fetch_first(a > 3) + 3 + >>> fetch_first(a > 30) + I_INT + """ + idx = np.argmax(mask) + return np.where(mask[idx], idx, default) + + +def fetch_last(mask, default=I_INT) -> NDArray: + """ + similar to fetch_first, but fetch the last True index + """ + reversed_idx = fetch_first(mask[::-1], default) + return np.where(reversed_idx == default, default, mask.shape[0] - reversed_idx - 1) + + +def fetch_random(mask, default=I_INT) -> NDArray: + """ + similar to fetch_first, but fetch a random True index + """ + true_cnt = np.sum(mask) + if true_cnt == 0: + return default + cumsum = np.cumsum(mask) + target = np.random.randint(1, true_cnt + 1, size=()) + return fetch_first(cumsum >= target, default) + + +if __name__ == '__main__': + a = np.array([1, 2, 3, 4, 5]) + print(fetch_first(a > 3)) + print(fetch_first(a > 30)) + + print(fetch_last(a > 3)) + print(fetch_last(a > 30)) + + for t in [-1, 0, 1, 2, 3, 4, 5]: + for _ in range(10): + print(t, fetch_random(a > t)) diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py index 57eef00..ee883e0 100644 --- a/algorithms/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -117,10 +117,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array: true_cnt = jnp.sum(mask) cumsum = jnp.cumsum(mask) target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) - return fetch_first(cumsum >= target, default) + mask = jnp.where(true_cnt == 0, False, cumsum >= target) + return fetch_first(mask, default) if __name__ == '__main__': + a = jnp.array([1, 2, 3, 4, 5]) print(fetch_first(a > 3)) print(fetch_first(a > 30)) @@ -129,6 +131,9 @@ if __name__ == '__main__': print(fetch_last(a > 30)) rand_key = jax.random.PRNGKey(0) - for _ in range(100): - rand_key, _ = jax.random.split(rand_key) - print(fetch_random(rand_key, a > 0)) + + for t in [-1, 0, 1, 2, 3, 4, 5]: + for _ in range(10): + rand_key, _ = jax.random.split(rand_key) + print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2)) + print(t, fetch_random(rand_key, a > t)) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 8dd0e21..7d931df 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -1,15 +1,12 @@ from typing import List, Union, Tuple, Callable import time -import jax import numpy as np from .species import SpeciesController -from .genome import create_initialize_function, create_mutate_function, create_forward_function -from .genome import batch_crossover -from .genome.crossover import crossover -from .genome import expand, expand_single -from algorithms.neat.genome.genome import pop_analysis, analysis +from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function +from .genome.numpy import batch_crossover +from .genome.numpy import expand, expand_single class Pipeline: @@ -18,7 +15,7 @@ class Pipeline: """ def __init__(self, config, seed=42): - self.randkey = jax.random.PRNGKey(seed) + np.random.seed(seed) self.config = config self.N = config.basic.init_maximum_nodes @@ -53,14 +50,6 @@ class Pipeline: def tell(self, fitnesses): self.generation += 1 - for i, f in enumerate(fitnesses): - if np.isnan(f): - print("fuck!!!!!!!!!!!!!!") - error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i] - np.save('error_nodes.npy', error_nodes) - np.save('error_connections.npy', error_connections) - assert False - self.species_controller.update_species_fitnesses(fitnesses) crossover_pair = self.species_controller.reproduce(self.generation) @@ -96,8 +85,6 @@ class Pipeline: assert self.pop_nodes.shape[0] == self.pop_size - k1, k2, self.randkey = jax.random.split(self.randkey, 3) - # crossover # prepare elitism mask and crossover pair elitism_mask = np.full(self.pop_size, False) @@ -112,18 +99,13 @@ class Pipeline: wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections - crossover_rand_keys = jax.random.split(k1, self.pop_size) # npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc) + npn, npc = batch_crossover(wpn, wpc, lpn, lpc) # print(pop_analysis(npn, npc, self.input_idx, self.output_idx)) # mutate new_node_keys = np.array(self.fetch_new_node_keys()) - mutate_rand_keys = jax.random.split(k2, self.pop_size) - m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes - m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) - - # print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx)) + m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes # elitism don't mutate # (pop_size, ) to (pop_size, 1, 1) @@ -180,21 +162,4 @@ class Pipeline: self.generation_timestamp = new_timestamp print(f"Generation: {self.generation}", - f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") - - # def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc): - # pop_nodes, pop_connections = [], [] - # for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc): - # new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc) - # pop_nodes.append(new_nodes) - # pop_connections.append(new_connections) - # try: - # print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx)) - # except AssertionError: - # new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc) - # return np.stack(pop_nodes), np.stack(pop_connections) - - # return batch_crossover(*args) - -def crossover_wrapper(*args): - return batch_crossover(*args) + f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") \ No newline at end of file diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index d513c82..053a71b 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -1,10 +1,9 @@ from typing import List, Tuple, Dict, Union from itertools import count -import jax import numpy as np from numpy.typing import NDArray -from .genome import distance +from .genome.numpy import distance class Species(object): @@ -46,10 +45,6 @@ class SpeciesController: self.species_idxer = count(0) self.species: Dict[int, Species] = {} # species_id -> species - self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many - # self.o2o_distance_func = np_distance # one to one - self.o2o_distance_func = distance - def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None: """ :param pop_nodes: @@ -67,8 +62,7 @@ class SpeciesController: for sid, species in self.species.items(): # calculate the distance between the representative and the population r_nodes, r_connections = species.representative - distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections) - distances = jax.device_get(distances) # fetch the data from gpu + distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance new_representatives[sid] = min_idx @@ -81,9 +75,7 @@ class SpeciesController: if previous_species_list: # exist previous species rid_list = [new_representatives[sid] for sid in previous_species_list] res_pop_distance = [ - jax.device_get( - self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) - ) + o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) for rid in rid_list ] @@ -110,7 +102,7 @@ class SpeciesController: sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) distances = [ - self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) + distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) for r in rid ] distances = np.array(distances) @@ -267,36 +259,6 @@ class SpeciesController: return crossover_pair - def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections): - # distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections) - # for d in distances: - # if np.isnan(d): - # print("fuck!!!!!!!!!!!!!!") - # print(distances) - # assert False - # return distances - distances = [] - for nodes, connections in zip(pop_nodes, pop_connections): - d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections) - if np.isnan(d) or d > 20: - np.save("too_large_distance_r_nodes.npy", r_nodes) - np.save("too_large_distance_r_connections.npy", r_connections) - np.save("too_large_distance_nodes", nodes) - np.save("too_large_distance_connections.npy", connections) - d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections) - assert False - distances.append(d) - distances = np.stack(distances, axis=0) - # print(distances) - return distances - - def o2o_distance_wrapper(self, *keys): - d = self.o2o_distance_func(*keys) - if np.isnan(d): - print("fuck!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - assert False - return d - def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): """ @@ -351,3 +313,12 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \ sorted_members = [item[0] for item in sorted_combined] sorted_fitnesses = [item[1] for item in sorted_combined] return sorted_members, sorted_fitnesses + + +def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections): + distances = [] + for nodes, connections in zip(pop_nodes, pop_connections): + d = distance(r_nodes, r_connections, nodes, connections) + distances.append(d) + distances = np.stack(distances, axis=0) + return distances diff --git a/algorithms/numpy/__init__.py b/algorithms/numpy/__init__.py deleted file mode 100644 index 0c1d0ac..0000000 --- a/algorithms/numpy/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -numpy version of functions in genome -""" -from .distance import distance -from .utils import * \ No newline at end of file diff --git a/algorithms/numpy/distance.py b/algorithms/numpy/distance.py deleted file mode 100644 index e56f2ff..0000000 --- a/algorithms/numpy/distance.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np - -from .utils import flatten_connections, set_operation_analysis - - -def distance(nodes1, connections1, nodes2, connections2): - node_distance = gene_distance(nodes1, nodes2, 'node') - - # refactor connections - keys1, keys2 = nodes1[:, 0], nodes2[:, 0] - cons1 = flatten_connections(keys1, connections1) - cons2 = flatten_connections(keys2, connections2) - - connection_distance = gene_distance(cons1, cons2, 'connection') - return node_distance + connection_distance - - -def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): - if gene_type == 'node': - keys1, keys2 = ar1[:, :1], ar2[:, :1] - else: # connection - keys1, keys2 = ar1[:, :2], ar2[:, :2] - - n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) - nodes = np.concatenate((ar1, ar2), axis=0) - sorted_nodes = nodes[n_sorted_indices] - fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] - - non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask) - if gene_type == 'node': - node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) - else: # connection - node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes) - - node_distance = np.where(np.isnan(node_distance), 0, node_distance) - homologous_distance = np.sum(node_distance * n_intersect_mask[:-1]) - - gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1)) - gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1)) - - val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe - return val / np.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2) - - -def homologous_node_distance(n1, n2): - d = 0 - d += np.abs(n1[:, 1] - n2[:, 1]) # bias - d += np.abs(n1[:, 2] - n2[:, 2]) # response - d += n1[:, 3] != n2[:, 3] # activation - d += n1[:, 4] != n2[:, 4] - return d - - -def homologous_connection_distance(c1, c2): - d = 0 - d += np.abs(c1[:, 2] - c2[:, 2]) # weight - d += c1[:, 3] != c2[:, 3] # enable - return d diff --git a/algorithms/numpy/utils.py b/algorithms/numpy/utils.py deleted file mode 100644 index 57119a3..0000000 --- a/algorithms/numpy/utils.py +++ /dev/null @@ -1,55 +0,0 @@ -import numpy as np - -I_INT = np.iinfo(np.int32).max # infinite int - - -def flatten_connections(keys, connections): - indices_x, indices_y = np.meshgrid(keys, keys, indexing='ij') - indices = np.stack((indices_x, indices_y), axis=-1).reshape(-1, 2) - - # make (2, N, N) to (N, N, 2) - con = np.transpose(connections, (1, 2, 0)) - # make (N, N, 2) to (N * N, 2) - con = np.reshape(con, (-1, 2)) - - con = np.concatenate((indices, con), axis=1) - return con - - -def unflatten_connections(N, cons): - cons = cons[:, 2:] # remove the indices - unflatten_cons = np.moveaxis(cons.reshape(N, N, 2), -1, 0) - return unflatten_cons - - -def set_operation_analysis(ar1, ar2): - ar = np.concatenate((ar1, ar2), axis=0) - sorted_indices = np.lexsort(ar.T[::-1]) - aux = ar[sorted_indices] - aux = np.concatenate((aux, np.full((1, ar1.shape[1]), np.nan)), axis=0) - nan_mask = np.any(np.isnan(aux), axis=1) - - fr, sr = aux[:-1], aux[1:] # first row, second row - intersect_mask = np.all(fr == sr, axis=1) & ~nan_mask[:-1] - union_mask = np.any(fr != sr, axis=1) & ~nan_mask[:-1] - return sorted_indices, intersect_mask, union_mask - - -def fetch_first(mask, default=I_INT): - idx = np.argmax(mask) - return np.where(mask[idx], idx, default) - - -def fetch_last(mask, default=I_INT): - reversed_idx = fetch_first(mask[::-1], default) - return np.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1) - - -def fetch_random(rand_key, mask, default=I_INT): - """ - similar to fetch_first, but fetch a random True index - """ - true_cnt = np.sum(mask) - cumsum = np.cumsum(mask) - target = np.random.randint(rand_key, shape=(), minval=0, maxval=true_cnt + 1) - return fetch_first(cumsum >= target, default)