From 414b620dc82006cdddea84ac9dac8d65bb9b0d4f Mon Sep 17 00:00:00 2001
From: wls2002
Date: Sat, 6 May 2023 23:26:13 +0800
Subject: [PATCH] =?UTF-8?q?=E8=99=BD=E7=84=B6xor=E9=97=AE=E9=A2=98?=
 =?UTF-8?q?=E8=BF=98=E6=98=AF=E8=B7=91=E4=B8=8D=E5=87=BA=E6=9D=A5=EF=BC=8C?=
 =?UTF-8?q?=E4=BD=86=E8=87=B3=E5=B0=91=E5=B7=B2=E7=BB=8F=E7=A1=AE=E5=AE=9A?=
 =?UTF-8?q?=E4=B8=8D=E6=98=AFdistance=E7=9A=84=E9=94=99=E4=BA=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Reviewer notes (commentary only; this text is not part of the commit):

  * Subject, decoded from the RFC 2047 encoded-words above: "Although the
    xor problem still cannot be solved, at least it is now confirmed that
    distance is not at fault."
  * NOTE(review): the line structure of this patch was reconstructed from a
    whitespace-collapsed copy. Line counts agree with every hunk header and
    with the diffstat, but indentation (4 spaces assumed), the spacing
    before inline comments and the position of blank context lines are
    inferred. Confirm against the repository before applying.
  * algorithms/neat/species.py: the added "if d < 0" branch recomputes the
    distance, prints it and then stops on "assert False". This looks like
    temporary debugging code; confirm before merging.
  * examples/xor.py: besides dropping jax.device_get, the fitness changes
    from -mean(squared error) to +mean(squared error). Confirm the sign
    change is intended.
  * examples/distance_test.py: main() loops forever ("while True") and only
    stops when the assertion fails.

 algorithms/neat/genome/numpy/__init__.py |  2 +-
 algorithms/neat/genome/numpy/distance.py | 64 +++++++++++++++++-
 algorithms/neat/species.py               |  4 ++
 examples/distance_test.py                | 84 ++++++++++++++++++++++++
 examples/xor.py                          |  4 +-
 5 files changed, 151 insertions(+), 7 deletions(-)
 create mode 100644 examples/distance_test.py

diff --git a/algorithms/neat/genome/numpy/__init__.py b/algorithms/neat/genome/numpy/__init__.py
index bb905b5..c25df1a 100644
--- a/algorithms/neat/genome/numpy/__init__.py
+++ b/algorithms/neat/genome/numpy/__init__.py
@@ -1,4 +1,4 @@
-from .genome import create_initialize_function, expand, expand_single
+from .genome import create_initialize_function, expand, expand_single, analysis
 from .distance import distance
 from .mutate import create_mutate_function
 from .forward import create_forward_function
diff --git a/algorithms/neat/genome/numpy/distance.py b/algorithms/neat/genome/numpy/distance.py
index b5d313a..8677aff 100644
--- a/algorithms/neat/genome/numpy/distance.py
+++ b/algorithms/neat/genome/numpy/distance.py
@@ -5,6 +5,9 @@ from numpy.typing import NDArray
 
 from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
 
+EMPTY_NODE = np.full((1, 5), np.nan)
+EMPTY_CON = np.full((1, 4), np.nan)
+
 
 def distance(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) -> NDArray:
     """
@@ -13,15 +16,70 @@ def distance(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connection
     connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
     """
 
-    node_distance = gene_distance(nodes1, nodes2, 'node')
+    nd = node_distance(nodes1, nodes2)  # node distance
 
     # refactor connections
     keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
     cons1 = flatten_connections(keys1, connections1)
     cons2 = flatten_connections(keys2, connections2)
+    cd = connection_distance(cons1, cons2)  # connection distance
+    return nd + cd
 
-    connection_distance = gene_distance(cons1, cons2, 'connection')
-    return node_distance + connection_distance
+
+def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
+    node_cnt1 = np.sum(~np.isnan(nodes1[:, 0]))
+    node_cnt2 = np.sum(~np.isnan(nodes2[:, 0]))
+    max_cnt = np.maximum(node_cnt1, node_cnt2)
+
+    nodes = np.concatenate((nodes1, nodes2), axis=0)
+    keys = nodes[:, 0]
+    sorted_indices = np.argsort(keys, axis=0)
+    nodes = nodes[sorted_indices]
+    nodes = np.concatenate([nodes, EMPTY_NODE], axis=0)  # add a nan row to the end
+    fr, sr = nodes[:-1], nodes[1:]  # first row, second row
+    nan_mask = np.isnan(nodes[:, 0])
+
+    intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
+
+    non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * np.sum(intersect_mask)
+    nd = batch_homologous_node_distance(fr, sr)
+    nd = np.where(np.isnan(nd), 0, nd)
+    homologous_distance = np.sum(nd * intersect_mask)
+
+    val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
+
+    if max_cnt == 0:  # consider the case that both genome has no gene
+        return 0
+    else:
+        return val / max_cnt
+
+
+def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
+    con_cnt1 = np.sum(~np.isnan(cons1[:, 2]))  # weight is not nan, means the connection exists
+    con_cnt2 = np.sum(~np.isnan(cons2[:, 2]))
+    max_cnt = np.maximum(con_cnt1, con_cnt2)
+
+    cons = np.concatenate((cons1, cons2), axis=0)
+    keys = cons[:, :2]
+    sorted_indices = np.lexsort(keys.T[::-1])
+    cons = cons[sorted_indices]
+    cons = np.concatenate([cons, EMPTY_CON], axis=0)  # add a nan row to the end
+    fr, sr = cons[:-1], cons[1:]  # first row, second row
+
+    # both genome has such connection
+    intersect_mask = np.all(fr[:, :2] == sr[:, :2], axis=1) & ~np.isnan(fr[:, 2]) & ~np.isnan(sr[:, 2])
+
+    non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * np.sum(intersect_mask)
+    cd = batch_homologous_connection_distance(fr, sr)
+    cd = np.where(np.isnan(cd), 0, cd)
+    homologous_distance = np.sum(cd * intersect_mask)
+
+    val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
+
+    if max_cnt == 0:  # consider the case that both genome has no gene
+        return 0
+    else:
+        return val / max_cnt
 
 
 def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py
index 053a71b..1412687 100644
--- a/algorithms/neat/species.py
+++ b/algorithms/neat/species.py
@@ -319,6 +319,10 @@ def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
     distances = []
     for nodes, connections in zip(pop_nodes, pop_connections):
         d = distance(r_nodes, r_connections, nodes, connections)
+        if d < 0:
+            d = distance(r_nodes, r_connections, nodes, connections)
+            print(d)
+            assert False
         distances.append(d)
     distances = np.stack(distances, axis=0)
     return distances
diff --git a/examples/distance_test.py b/examples/distance_test.py
new file mode 100644
index 0000000..1a3ea28
--- /dev/null
+++ b/examples/distance_test.py
@@ -0,0 +1,84 @@
+from typing import Callable, List
+from functools import partial
+
+import numpy as np
+
+from utils import Configer
+from algorithms.neat.genome.numpy import analysis, distance
+from algorithms.neat.genome.numpy import create_initialize_function, create_mutate_function
+
+
+def real_distance(nodes1, connections1, nodes2, connections2, input_idx, output_idx):
+    nodes1, connections1 = analysis(nodes1, connections1, input_idx, output_idx)
+    nodes2, connections2 = analysis(nodes2, connections2, input_idx, output_idx)
+    compatibility_coe = 0.5
+    disjoint_coe = 1.
+    node_distance = 0.0
+    if nodes1 or nodes2:  # otherwise, both are empty
+        disjoint_nodes = 0
+        for k2 in nodes2:
+            if k2 not in nodes1:
+                disjoint_nodes += 1
+
+        for k1, n1 in nodes1.items():
+            n2 = nodes2.get(k1)
+            if n2 is None:
+                disjoint_nodes += 1
+            else:
+                if n1[0] is None:
+                    continue
+                d = abs(n1[0] - n2[0]) + abs(n1[1] - n2[1])
+                d += 1 if n1[2] != n2[2] else 0
+                d += 1 if n1[3] != n2[3] else 0
+                node_distance += d
+
+        max_nodes = max(len(nodes1), len(nodes2))
+        node_distance = (compatibility_coe * node_distance + disjoint_coe * disjoint_nodes) / max_nodes
+
+    connection_distance = 0.0
+    if connections1 or connections2:
+        disjoint_connections = 0
+        for k2 in connections2:
+            if k2 not in connections1:
+                disjoint_connections += 1
+
+        for k1, c1 in connections1.items():
+            c2 = connections2.get(k1)
+            if c2 is None:
+                disjoint_connections += 1
+            else:
+                # Homologous genes compute their own distance value.
+                d = abs(c1[0] - c2[0])
+                d += 1 if c1[1] != c2[1] else 0
+                connection_distance += d
+        max_conn = max(len(connections1), len(connections2))
+        connection_distance = (compatibility_coe * connection_distance + disjoint_coe * disjoint_connections) / max_conn
+
+    return node_distance + connection_distance
+
+
+def main():
+    config = Configer.load_config()
+    keys_idx = config.basic.num_inputs + config.basic.num_outputs
+    pop_size = config.neat.population.pop_size
+    init_func = create_initialize_function(config)
+    pop_nodes, pop_connections, input_idx, output_idx = init_func()
+
+    mutate_func = create_mutate_function(config, input_idx, output_idx, batch=True)
+
+    while True:
+        pop_nodes, pop_connections = mutate_func(pop_nodes, pop_connections, list(range(keys_idx, keys_idx + pop_size)))
+        keys_idx += pop_size
+        for i in range(pop_size):
+            for j in range(pop_size):
+                nodes1, connections1 = pop_nodes[i], pop_connections[i]
+                nodes2, connections2 = pop_nodes[j], pop_connections[j]
+                numpy_d = distance(nodes1, connections1, nodes2, connections2)
+                real_d = real_distance(nodes1, connections1, nodes2, connections2, input_idx, output_idx)
+                assert np.isclose(numpy_d, real_d), f'{numpy_d} != {real_d}'
+                print(numpy_d, real_d)
+
+
+if __name__ == '__main__':
+    np.random.seed(0)
+    main()
diff --git a/examples/xor.py b/examples/xor.py
index 044d0a6..0ff7558 100644
--- a/examples/xor.py
+++ b/examples/xor.py
@@ -1,7 +1,6 @@
 from typing import Callable, List
 from functools import partial
 
-import jax
 import numpy as np
 
 from utils import Configer
@@ -18,8 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]:
     :return:
     """
     outs = forward_func(xor_inputs)
-    outs = jax.device_get(outs)
-    fitnesses = -np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
+    fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
     # print(fitnesses)
     return fitnesses.tolist()  # returns a list
 