From 35b095ba745a903faf3e210cf03129df6b0c2c2e Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 19 Jun 2023 17:32:34 +0800 Subject: [PATCH] modifying --- configs/configer.py | 10 +- neat/genome/crossover_.py | 17 +- neat/genome/distance_.py | 66 ++++--- neat/genome/genome_.py | 11 +- neat/genome/mutate_.py | 362 ++++++++++++++++++++++++++++++++++++++ neat/pipeline_.py | 4 +- 6 files changed, 428 insertions(+), 42 deletions(-) create mode 100644 neat/genome/mutate_.py diff --git a/configs/configer.py b/configs/configer.py index 6e8f269..b433256 100644 --- a/configs/configer.py +++ b/configs/configer.py @@ -2,11 +2,15 @@ import os import warnings import configparser +import numpy as np + from .activations import refactor_act from .aggregations import refactor_agg # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. jit_config_keys = [ + "input_idx", + "output_idx", "compatibility_disjoint", "compatibility_weight", "conn_add_prob", @@ -88,10 +92,14 @@ class Configer: refactor_act(config) refactor_agg(config) - + input_idx = np.arange(config['num_inputs']) + output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) + config['input_idx'] = input_idx + config['output_idx'] = output_idx return config @classmethod def create_jit_config(cls, config): jit_config = {k: config[k] for k in jit_config_keys} + return jit_config diff --git a/neat/genome/crossover_.py b/neat/genome/crossover_.py index 0873b98..2a02d9b 100644 --- a/neat/genome/crossover_.py +++ b/neat/genome/crossover_.py @@ -1,14 +1,17 @@ -from functools import partial +""" +Crossover two genomes to generate a new genome. +The calculation method is the same as the crossover operation in NEAT-python. +See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover +""" from typing import Tuple import jax -from jax import jit, vmap, Array +from jax import jit, Array from jax import numpy as jnp @jit -def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \ - -> Tuple[Array, Array]: +def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]: """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) @@ -23,7 +26,11 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + # make homologous genes align in nodes2 align with nodes1 nodes2 = align_array(keys1, keys2, nodes2, 'node') + + # For not homologous genes, use the value of nodes1(winner) + # For homologous genes, use the crossover result between nodes1 and nodes2 new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) # crossover connections @@ -34,7 +41,6 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: return new_nodes, new_cons -# @partial(jit, static_argnames=['gene_type']) def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: """ After I review this code, I found that it is the most difficult part of the code. Please never change it! @@ -62,7 +68,6 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: return refactor_ar2 -# @jit def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: """ crossover two genes diff --git a/neat/genome/distance_.py b/neat/genome/distance_.py index b85b6c7..69e421e 100644 --- a/neat/genome/distance_.py +++ b/neat/genome/distance_.py @@ -1,6 +1,7 @@ """ Calculate the distance between two genomes. The calculation method is the same as the distance calculation in NEAT-python. +See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py """ from typing import Dict @@ -14,6 +15,13 @@ from .utils import EMPTY_NODE, EMPTY_CON def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array: """ Calculate the distance between two genomes. + args: + nodes1: Array(N, 5) + cons1: Array(C, 4) + nodes2: Array(N, 5) + cons2: Array(C, 4) + returns: + distance: Array(, ) """ nd = node_distance(nodes1, nodes2, jit_config) # node distance cd = connection_distance(cons1, cons2, jit_config) # connection distance @@ -23,13 +31,15 @@ def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_confi @jit def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict): """ - Calculate the distance between two nodes. + Calculate the distance between nodes of two genomes. """ - + # statistics nodes count of two genomes node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) max_cnt = jnp.maximum(node_cnt1, node_cnt2) + # align homologous nodes + # this process is similar to np.intersect1d. nodes = jnp.concatenate((nodes1, nodes2), axis=0) keys = nodes[:, 0] sorted_indices = jnp.argsort(keys, axis=0) @@ -37,21 +47,28 @@ def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict): nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end fr, sr = nodes[:-1], nodes[1:] # first row, second row + # flag location of homologous nodes intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) + # calculate the count of non_homologous of two genomes non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) - nd = batch_homologous_node_distance(fr, sr) - nd = jnp.where(jnp.isnan(nd), 0, nd) - homologous_distance = jnp.sum(nd * intersect_mask) - val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe - return jnp.where(max_cnt == 0, 0, val / max_cnt) + # calculate the distance of homologous nodes + hnd = vmap(homologous_node_distance)(fr, sr) + hnd = jnp.where(jnp.isnan(hnd), 0, hnd) + homologous_distance = jnp.sum(hnd * intersect_mask) + + val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ + 'compatibility_weight'] + + return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division @jit -def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): +def connection_distance(cons1: Array, cons2: Array, jit_config: Dict): """ - Calculate the distance between two connections. + Calculate the distance between connections of two genomes. + Similar process as node_distance. """ con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) @@ -68,37 +85,34 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - cd = batch_homologous_connection_distance(fr, sr) - cd = jnp.where(jnp.isnan(cd), 0, cd) - homologous_distance = jnp.sum(cd * intersect_mask) + hcd = vmap(homologous_connection_distance)(fr, sr) + hcd = jnp.where(jnp.isnan(hcd), 0, hcd) + homologous_distance = jnp.sum(hcd * intersect_mask) - val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe + val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ + 'compatibility_weight'] return jnp.where(max_cnt == 0, 0, val / max_cnt) -@vmap -def batch_homologous_node_distance(b_n1, b_n2): - return homologous_node_distance(b_n1, b_n2) - - -@vmap -def batch_homologous_connection_distance(b_c1, b_c2): - return homologous_connection_distance(b_c1, b_c2) - - @jit -def homologous_node_distance(n1, n2): +def homologous_node_distance(n1: Array, n2: Array): + """ + Calculate the distance between two homologous nodes. + """ d = 0 d += jnp.abs(n1[1] - n2[1]) # bias d += jnp.abs(n1[2] - n2[2]) # response d += n1[3] != n2[3] # activation - d += n1[4] != n2[4] + d += n1[4] != n2[4] # aggregation return d @jit -def homologous_connection_distance(c1, c2): +def homologous_connection_distance(c1: Array, c2: Array): + """ + Calculate the distance between two homologous connections. + """ d = 0 d += jnp.abs(c1[2] - c2[2]) # weight d += c1[3] != c2[3] # enable diff --git a/neat/genome/genome_.py b/neat/genome/genome_.py index 7b61131..832de39 100644 --- a/neat/genome/genome_.py +++ b/neat/genome/genome_.py @@ -17,10 +17,7 @@ from jax import jit, numpy as jnp from .utils import fetch_first -def initialize_genomes(N: int, - C: int, - config: Dict) \ - -> Tuple[NDArray, NDArray, NDArray, NDArray]: +def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]: """ Initialize genomes with default values. @@ -41,8 +38,8 @@ def initialize_genomes(N: int, pop_nodes = np.full((config['pop_size'], N, 5), np.nan) pop_cons = np.full((config['pop_size'], C, 4), np.nan) - input_idx = np.arange(config['num_inputs']) - output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) + input_idx = config['input_idx'] + output_idx = config['output_idx'] pop_nodes[:, input_idx, 0] = input_idx pop_nodes[:, output_idx, 0] = output_idx @@ -61,7 +58,7 @@ def initialize_genomes(N: int, pop_cons[:, :p, 2] = config['weight_init_mean'] pop_cons[:, :p, 3] = 1 - return pop_nodes, pop_cons, input_idx, output_idx + return pop_nodes, pop_cons def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]: diff --git a/neat/genome/mutate_.py b/neat/genome/mutate_.py new file mode 100644 index 0000000..384cc67 --- /dev/null +++ b/neat/genome/mutate_.py @@ -0,0 +1,362 @@ +""" +Mutate a genome. +The calculation method is the same as the mutation operation in NEAT-python. +See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate +""" +from typing import Tuple, Dict +from functools import partial + +import jax +from jax import numpy as jnp +from jax import jit, Array + +from .utils import fetch_random, fetch_first, I_INT, unflatten_connections +from .genome_ import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection +from .graph import check_cycles + + +@jit +def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict): + """ + :param rand_key: + :param nodes: (N, 5) + :param connections: (2, N, N) + :param new_node_key: + :param jit_config: + :return: + """ + + def m_add_node(rk, n, c): + return mutate_add_node(rk, n, c, new_node_key, jit_config['bias_init_mean'], jit_config['response_init_mean'], + jit_config['activation_default'], jit_config['aggregation_default']) + + def m_add_connection(rk, n, c): + return mutate_add_connection(rk, n, c, jit_config['input_idx'], jit_config['output_idx']) + + def m_delete_node(rk, n, c): + return mutate_delete_node(rk, n, c, jit_config['input_idx'], jit_config['output_idx']) + + def m_delete_connection(rk, n, c): + return mutate_delete_connection(rk, n, c) + + r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) + + # structural mutations + # mutate add node + r = rand(r1) + aux_nodes, aux_connections = m_add_node(r1, nodes, connections) + nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections) + + # mutate add connection + r = rand(r2) + aux_nodes, aux_connections = m_add_connection(r3, nodes, connections) + nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections) + + # mutate delete node + r = rand(r3) + aux_nodes, aux_connections = m_delete_node(r2, nodes, connections) + nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections) + + # mutate delete connection + r = rand(r4) + aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections) + nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes) + connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections) + + # value mutations + nodes, connections = mutate_values(rand_key, nodes, connections, jit_config) + + return nodes, connections + + +@jit +def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: + """ + Mutate values of nodes and connections. + + Args: + rand_key: A random key for generating random values. + nodes: A 2D array representing nodes. + cons: A 3D array representing connections. + jit_config: A dict containing configuration for jit-able functions. + + Returns: + A tuple containing mutated nodes and connections. + """ + + k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6) + bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std, + bias_mutate_strength, bias_mutate_rate, bias_replace_rate) + response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate) + weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std, + weight_mutate_strength, weight_mutate_rate, weight_replace_rate) + act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate) + agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate) + + # mutate enabled + r = jax.random.uniform(rand_key, cons[:, 3].shape) + enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3]) + enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan) + + nodes = nodes.at[:, 1].set(bias_new) + nodes = nodes.at[:, 2].set(response_new) + nodes = nodes.at[:, 3].set(act_new) + nodes = nodes.at[:, 4].set(agg_new) + cons = cons.at[:, 2].set(weight_new) + cons = cons.at[:, 3].set(enabled_new) + return nodes, cons + + +@jit +def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, + mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: + """ + Mutate float values of a given array. + + Args: + rand_key: A random key for generating random values. + old_vals: A 1D array of float values to be mutated. + mean: Mean of the values. + std: Standard deviation of the values. + mutate_strength: Strength of the mutation. + mutate_rate: Rate of the mutation. + replace_rate: Rate of the replacement. + + Returns: + A mutated 1D array of float values. + """ + k1, k2, k3, rand_key = jax.random.split(rand_key, num=4) + noise = jax.random.normal(k1, old_vals.shape) * mutate_strength + replace = jax.random.normal(k2, old_vals.shape) * std + mean + r = jax.random.uniform(k3, old_vals.shape) + new_vals = old_vals + new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals) + new_vals = jnp.where( + jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate), + replace, + new_vals + ) + new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) + return new_vals + + +@jit +def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: + """ + Mutate integer values (act, agg) of a given array. + + Args: + rand_key: A random key for generating random values. + old_vals: A 1D array of integer values to be mutated. + val_list: List of the integer values. + replace_rate: Rate of the replacement. + + Returns: + A mutated 1D array of integer values. + """ + k1, k2, rand_key = jax.random.split(rand_key, num=3) + replace_val = jax.random.choice(k1, val_list, old_vals.shape) + r = jax.random.uniform(k2, old_vals.shape) + new_vals = old_vals + new_vals = jnp.where(r < replace_rate, replace_val, new_vals) + new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) + return new_vals + + +@jit +def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int, + default_bias: float = 0, default_response: float = 1, + default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: + """ + Randomly add a new node from splitting a connection. + :param rand_key: + :param new_node_key: + :param nodes: + :param cons: + :param default_bias: + :param default_response: + :param default_act: + :param default_agg: + :return: + """ + # randomly choose a connection + i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) + + def nothing(): # there is no connection to split + return nodes, cons + + def successful_add_node(): + # disable the connection + new_nodes, new_cons = nodes, cons + new_cons = new_cons.at[idx, 3].set(False) + + # add a new node + new_nodes, new_cons = \ + add_node(new_nodes, new_cons, new_node_key, + bias=default_bias, response=default_response, act=default_act, agg=default_agg) + + # add two new connections + w = new_cons[idx, 2] + new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True) + new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True) + return new_nodes, new_cons + + # if from_idx == I_INT, that means no connection exist, do nothing + nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node) + + return nodes, cons + + +# TODO: Need we really need to delete a node? +@jit +def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, + input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: + """ + Randomly delete a node. Input and output nodes are not allowed to be deleted. + :param rand_key: + :param nodes: + :param cons: + :param input_keys: + :param output_keys: + :return: + """ + # randomly choose a node + node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys, + allow_input_keys=False, allow_output_keys=False) + + def nothing(): + return nodes, cons + + def successful_delete_node(): + # delete the node + aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx) + + # delete all connections + aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis], + jnp.nan, aux_cons) + + return aux_nodes, aux_cons + + nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node) + + return nodes, cons + + +@jit +def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, + input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: + """ + Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, + cycles are not allowed. + :param rand_key: + :param nodes: + :param cons: + :param input_keys: + :param output_keys: + :return: + """ + # randomly choose two nodes + k1, k2 = jax.random.split(rand_key, num=2) + i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys, + allow_input_keys=True, allow_output_keys=True) + o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys, + allow_input_keys=False, allow_output_keys=True) + + con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) + + def successful(): + new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True) + return new_nodes, new_cons + + def already_exist(): + new_cons = cons.at[con_idx, 3].set(True) + return nodes, new_cons + + def cycle(): + return nodes, cons + + is_already_exist = con_idx != I_INT + unflattened = unflatten_connections(nodes, cons) + is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx) + + choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) + nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful]) + return nodes, cons + + +@jit +def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array): + """ + Randomly delete a connection. + :param rand_key: + :param nodes: + :param cons: + :return: + """ + # randomly choose a connection + i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) + + def nothing(): + return nodes, cons + + def successfully_delete_connection(): + return delete_connection_by_idx(nodes, cons, idx) + + nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) + + return nodes, cons + + +@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) +def choice_node_key(rand_key: Array, nodes: Array, + input_keys: Array, output_keys: Array, + allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: + """ + Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. + :param rand_key: + :param nodes: + :param input_keys: + :param output_keys: + :param allow_input_keys: + :param allow_output_keys: + :return: return its key and position(idx) + """ + + node_keys = nodes[:, 0] + mask = ~jnp.isnan(node_keys) + + if not allow_input_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) + + if not allow_output_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) + + idx = fetch_random(rand_key, mask) + key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) + return key, idx + + +@jit +def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]: + """ + Randomly choose a connection key from the given connections. + :param rand_key: + :param nodes: + :param cons: + :return: i_key, o_key, idx + """ + + idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0])) + i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan) + o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan) + + return i_key, o_key, idx + + +@jit +def rand(rand_key): + return jax.random.uniform(rand_key, ()) diff --git a/neat/pipeline_.py b/neat/pipeline_.py index 335efb5..adafe00 100644 --- a/neat/pipeline_.py +++ b/neat/pipeline_.py @@ -21,7 +21,7 @@ class Pipeline: self.generation = 0 self.best_genome = None - self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = initialize_genomes(self.N, self.C, self.config) + self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config) - print(self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx, sep='\n') + print(self.pop_nodes, self.pop_cons, sep='\n') print(self.jit_config)