import numpy as np
from numpy.typing import NDArray


def _decode_genome(nodes, connections):
    """
    Convert the NaN-padded genome arrays into dictionaries indexed by gene key.

    :param nodes: 2-D array; column 0 is the node key (NaN marks an unused row),
        columns 1-4 are the node attributes.
    :param connections: 3-D array; ``connections[0]`` holds weights and
        ``connections[1]`` holds enabled flags, both indexed by the *row
        positions* of the in/out nodes inside ``nodes``.
    :return: ``(node_genes, connection_genes)`` where ``node_genes`` maps
        ``key -> (attr1, attr2, attr3, attr4)`` and ``connection_genes`` maps
        ``(in_key, out_key) -> (weight or None, enabled or None)``.
    :raises KeyError: if an occupied connection cell refers to a row/column
        whose node slot is empty.
    """
    nodes = np.asarray(nodes)
    connections = np.asarray(connections)

    node_genes = {}
    idx2key = {}  # row position in `nodes` -> node key
    for idx, row in enumerate(nodes):
        if np.isnan(row[0]):  # NaN key: unused padding row
            continue
        key = int(row[0])
        node_genes[key] = (row[1], row[2], row[3], row[4])
        idx2key[idx] = key

    weights, enabled = connections[0], connections[1]
    weight_missing = np.isnan(weights)
    enabled_missing = np.isnan(enabled)

    # A cell is an actual connection unless BOTH of its fields are NaN.
    # The masks are computed once and only occupied cells are visited
    # (row-major order, same as a nested i/j loop).
    connection_genes = {}
    for i, j in np.argwhere(~(weight_missing & enabled_missing)).tolist():
        connection_genes[(idx2key[i], idx2key[j])] = (
            None if weight_missing[i, j] else weights[i, j],
            None if enabled_missing[i, j] else bool(enabled[i, j] == 1),
        )

    return node_genes, connection_genes


def _node_gene_distance(gene1, gene2):
    """Distance between two node genes that share the same key."""
    # The original author notes that a NaN here marks an input node; such
    # nodes contribute nothing to the homologous distance.
    if np.isnan(gene1[0]):
        return 0.0
    # First two attributes are compared numerically, last two categorically
    # (presumably bias/response and activation/aggregation - confirm with the
    # genome layout).
    d = abs(gene1[0] - gene2[0]) + abs(gene1[1] - gene2[1])
    d += 1 if gene1[2] != gene2[2] else 0
    d += 1 if gene1[3] != gene2[3] else 0
    return d


def _connection_gene_distance(gene1, gene2):
    """Distance between two connection genes that share the same key."""
    (weight1, enabled1), (weight2, enabled2) = gene1, gene2
    if weight1 is None or weight2 is None:
        # A weight missing on exactly one side counts as a unit mismatch,
        # the same way a missing `enabled` flag does via `!=` below.
        # Subtracting None would raise TypeError.
        d = 0.0 if (weight1 is None and weight2 is None) else 1.0
    else:
        d = abs(weight1 - weight2)
    d += 1 if enabled1 != enabled2 else 0
    return d


def _gene_set_distance(genes1, genes2, homologous_distance, disjoint_coe, compatibility_coe):
    """
    Combine disjoint-gene count and homologous-gene differences for one gene kind.

    Returns ``(compatibility_coe * homologous + disjoint_coe * disjoint) / max_size``,
    or ``0.0`` when both gene sets are empty.
    """
    if not genes1 and not genes2:
        return 0.0

    # Genes present in exactly one of the two genomes.
    disjoint = len(genes1.keys() ^ genes2.keys())

    homologous = 0.0
    for key, gene1 in genes1.items():
        gene2 = genes2.get(key)
        if gene2 is not None:
            homologous += homologous_distance(gene1, gene2)

    largest = max(len(genes1), len(genes2))
    return (compatibility_coe * homologous + disjoint_coe * disjoint) / largest


def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray,
                   connection2: NDArray, disjoint_coe: float = 1., compatibility_coe: float = 0.5) -> float:
    """
    Compute the one-to-one (o2o) distance between two genomes with plain NumPy.

    The distance is the sum of a node term and a connection term. Each term is
    ``(compatibility_coe * homologous_difference + disjoint_coe * disjoint_count)``
    divided by the larger gene count of the two genomes (0 when both are empty).

    :param nodes1: node array of the first genome; column 0 is the node key
        (NaN for unused rows), columns 1-4 are node attributes.
    :param connection1: connection array of the first genome; index 0 on the
        first axis holds weights, index 1 holds enabled flags, NaN means absent.
    :param nodes2: node array of the second genome, same layout as ``nodes1``.
    :param connection2: connection array of the second genome, same layout as
        ``connection1``.
    :param disjoint_coe: weight applied to the number of genes found in only
        one of the two genomes.
    :param compatibility_coe: weight applied to the summed attribute
        differences of genes found in both genomes.
    :return: the genome distance as a Python float.
    :raises KeyError: if a connection refers to an empty node slot.
    """
    node_genes1, connection_genes1 = _decode_genome(nodes1, connection1)
    node_genes2, connection_genes2 = _decode_genome(nodes2, connection2)

    nd = _gene_set_distance(node_genes1, node_genes2, _node_gene_distance,
                            disjoint_coe, compatibility_coe)
    cd = _gene_set_distance(connection_genes1, connection_genes2, _connection_gene_distance,
                            disjoint_coe, compatibility_coe)

    return float(nd + cd)