diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index 39bc0ce..b4e78fb 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -5,8 +5,6 @@ import jax from jax import jit, vmap, Array from jax import numpy as jnp -from .utils import flatten_connections, unflatten_connections - @jit def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ @@ -29,11 +27,8 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) # crossover connections - cons1 = flatten_connections(keys1, connections1) - cons2 = flatten_connections(keys2, connections2) - con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] - cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') - new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) - new_cons = unflatten_connections(len(keys1), new_cons) + con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2] + connections2 = align_array(con_keys1, con_keys2, connections2, 'connection') + new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections2), connections1, crossover_gene(randkey_2, connections1, connections2)) return new_nodes, new_cons @@ -42,6 +37,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, @partial(jit, static_argnames=['gene_type']) def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: """ + After I review this code, I found that it is the most difficult part of the code. Please never change it! make ar2 align with ar1. 
:param seq1: :param seq2: diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 6a34cbe..74fea62 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -3,17 +3,15 @@ Vectorization of genome representation. Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where: -1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes -too large to be represented by the current value of N. +1. N, C are pre-set values that determine the maximum number of nodes and connections in the network, and will increase if the genome becomes +too large to be represented by the current values of N and C. 2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function (act), and aggregation function (agg). -3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled -status. +3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled. Empty nodes or connections are represented using np.nan. """ from typing import Tuple, Dict -from functools import partial import jax import numpy as np @@ -22,13 +20,12 @@ from jax import numpy as jnp from jax import jit from jax import Array -from .utils import fetch_first - -EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) +from .utils import fetch_first, EMPTY_NODE def initialize_genomes(pop_size: int, N: int, + C: int, num_inputs: int, num_outputs: int, default_bias: float = 0.0, @@ -43,6 +40,7 @@ def initialize_genomes(pop_size: int, Args: pop_size (int): Number of genomes to initialize. N (int): Maximum number of nodes in the network. + C (int): Maximum number of connections in the network. num_inputs (int): Number of input nodes. num_outputs (int): Number of output nodes. 
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0. @@ -60,9 +58,11 @@ def initialize_genomes(pop_size: int, # Reserve one row for potential mutation adding an extra node assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \ f"{num_inputs} and output_size: {num_outputs}!" + assert num_inputs * num_outputs + 1 <= C, f"Too small C: {C} for input_size: " \ + f"{num_inputs} and output_size: {num_outputs}!" pop_nodes = np.full((pop_size, N, 5), np.nan) - pop_connections = np.full((pop_size, 2, N, N), np.nan) + pop_cons = np.full((pop_size, C, 4), np.nan) input_idx = np.arange(num_inputs) output_idx = np.arange(num_inputs, num_inputs + num_outputs) @@ -74,64 +74,69 @@ def initialize_genomes(pop_size: int, pop_nodes[:, output_idx, 3] = default_act pop_nodes[:, output_idx, 4] = default_agg - for i in input_idx: - for j in output_idx: - pop_connections[:, 0, i, j] = default_weight - pop_connections[:, 1, i, j] = 1 + grid_a, grid_b = np.meshgrid(input_idx, output_idx) + grid_a, grid_b = grid_a.flatten(), grid_b.flatten() - return pop_nodes, pop_connections, input_idx, output_idx + pop_cons[:, :num_inputs * num_outputs, 0] = grid_a + pop_cons[:, :num_inputs * num_outputs, 1] = grid_b + pop_cons[:, :num_inputs * num_outputs, 2] = default_weight + pop_cons[:, :num_inputs * num_outputs, 3] = 1 + + return pop_nodes, pop_cons, input_idx, output_idx -def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: +def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]: """ Expand the genome to accommodate more nodes. 
:param pop_nodes: (pop_size, N, 5) - :param pop_connections: (pop_size, 2, N, N) + :param pop_cons: (pop_size, C, 4) :param new_N: + :param new_C: :return: """ - pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1] + pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1] new_pop_nodes = np.full((pop_size, new_N, 5), np.nan) new_pop_nodes[:, :old_N, :] = pop_nodes - new_pop_connections = np.full((pop_size, 2, new_N, new_N), np.nan) - new_pop_connections[:, :, :old_N, :old_N] = pop_connections - return new_pop_nodes, new_pop_connections + new_pop_cons = np.full((pop_size, new_C, 4), np.nan) + new_pop_cons[:, :old_C, :] = pop_cons + + return new_pop_nodes, new_pop_cons -def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: +def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]: """ Expand a single genome to accommodate more nodes. :param nodes: (N, 5) - :param connections: (2, N, N) + :param cons: (C, 4) :param new_N: + :param new_C: :return: """ - old_N = nodes.shape[0] + old_N, old_C = nodes.shape[0], cons.shape[0] new_nodes = np.full((new_N, 5), np.nan) new_nodes[:old_N, :] = nodes - new_connections = np.full((2, new_N, new_N), np.nan) - new_connections[:, :old_N, :old_N] = connections + new_cons = np.full((new_C, 4), np.nan) + new_cons[:old_C, :] = cons - return new_nodes, new_connections + return new_nodes, new_cons -def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \ +def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \ Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]: """ Convert a genome from array to dict. 
:param nodes: (N, 5) - :param connections: (2, N, N) + :param cons: (C, 4) :param output_keys: :param input_keys: - :return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)] + :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)] """ # update nodes_dict try: nodes_dict = {} - idx2key = {} for i, node in enumerate(nodes): if np.isnan(node[0]): continue @@ -143,7 +148,6 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \ act = node[3] if not np.isnan(node[3]) else None agg = node[4] if not np.isnan(node[4]) else None nodes_dict[key] = (bias, response, act, agg) - idx2key[i] = key # check nodes_dict for i in input_keys: @@ -162,117 +166,109 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \ f"Normal node {k} must has non-None bias, response, act, or agg!" # update connections - connections_dict = {} - for i in range(connections.shape[1]): - for j in range(connections.shape[2]): - if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]): - continue - assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!" - assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!" - key = (idx2key[i], idx2key[j]) + cons_dict = {} + for i, con in enumerate(cons): + if np.isnan(con[0]): + continue + assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!" + i_key = int(con[0]) + o_key = int(con[1]) + assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!" + assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!" + key = (i_key, o_key) + weight = con[2] if not np.isnan(con[2]) else None + enabled = (con[3] == 1) if not np.isnan(con[3]) else None + assert weight is not None, f"Connection {key} must has non-None weight!" + assert enabled is not None, f"Connection {key} must has non-None enabled!" 
- weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None - enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None + cons_dict[key] = (weight, enabled) - assert weight is not None, f"Connection {key} must has non-None weight!" - assert enabled is not None, f"Connection {key} must has non-None enabled!" - connections_dict[key] = (weight, enabled) - - return nodes_dict, connections_dict + return nodes_dict, cons_dict except AssertionError: print(nodes) - print(connections) + print(cons) raise AssertionError -def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys): - pop_nodes, pop_connections = jax.device_get((pop_nodes, pop_connections)) +def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys): res = [] - for nodes, connections in zip(pop_nodes, pop_connections): - res.append(analysis(nodes, connections, input_keys, output_keys)) + for nodes, cons in zip(pop_nodes, pop_cons): + res.append(analysis(nodes, cons, input_keys, output_keys)) return res @jit -def count(nodes, connections): +def count(nodes, cons): node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0])) - connections_cnt = jnp.sum(~jnp.isnan(connections[0, :, :])) - return node_cnt, connections_cnt + cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0])) + return node_cnt, cons_cnt @jit -def add_node(new_node_key: int, nodes: Array, connections: Array, +def add_node(nodes: Array, cons: Array, new_key: int, bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: """ add a new node to the genome. 
""" exist_keys = nodes[:, 0] idx = fetch_first(jnp.isnan(exist_keys)) - nodes = nodes.at[idx].set(jnp.array([new_node_key, bias, response, act, agg])) - return nodes, connections + nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg])) + return nodes, cons @jit -def delete_node(node_key: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: +def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]: """ delete a node from the genome. only delete the node, regardless of connections. """ node_keys = nodes[:, 0] idx = fetch_first(node_keys == node_key) - return delete_node_by_idx(idx, nodes, connections) + return delete_node_by_idx(nodes, cons, idx) @jit -def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: +def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]: """ - delete a node from the genome. only delete the node, regardless of connections. + use idx to delete a node from the genome. only delete the node, regardless of connections. """ - # node_keys = nodes[:, 0] nodes = nodes.at[idx].set(EMPTY_NODE) - # move the last node to the deleted node's position - # last_real_idx = fetch_last(~jnp.isnan(node_keys)) - # nodes = nodes.at[idx].set(nodes[last_real_idx]) - # nodes = nodes.at[last_real_idx].set(EMPTY_NODE) - return nodes, connections + return nodes, cons @jit -def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array, +def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int, weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]: """ add a new connection to the genome. 
""" - node_keys = nodes[:, 0] - from_idx = fetch_first(node_keys == from_node) - to_idx = fetch_first(node_keys == to_node) - return add_connection_by_idx(from_idx, to_idx, nodes, connections, weight, enabled) + con_keys = cons[:, 0] + idx = fetch_first(jnp.isnan(con_keys)) + return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled) @jit -def add_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array, +def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int, weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]: """ - add a new connection to the genome. + use idx to add a new connection to the genome. """ - connections = connections.at[:, from_idx, to_idx].set(jnp.array([weight, enabled])) - return nodes, connections + cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled])) + return nodes, cons @jit -def delete_connection(from_node: int, to_node: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: +def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]: """ delete a connection from the genome. """ - node_keys = nodes[:, 0] - from_idx = fetch_first(node_keys == from_node) - to_idx = fetch_first(node_keys == to_node) - return delete_connection_by_idx(from_idx, to_idx, nodes, connections) + idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) + return delete_connection_by_idx(nodes, cons, idx) @jit -def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: +def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]: """ - delete a connection from the genome. + use idx to delete a connection from the genome. 
""" - connections = connections.at[:, from_idx, to_idx].set(np.nan) - return nodes, connections + cons = cons.at[idx].set(np.nan) + return nodes, cons diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 29526a8..f052efc 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,5 +1,11 @@ -import jax.numpy as jnp +import numpy as np -EMPTY_NODE = jnp.full((1, 5), jnp.nan) +# 输入 +a = np.array([1, 2, 3, 4]) +b = np.array([5, 6]) -print(EMPTY_NODE) \ No newline at end of file +# 创建一个网格,其中包含所有可能的组合 +aa, bb = np.meshgrid(a, b) +aa = aa.flatten() +bb = bb.flatten() +print(aa, bb) \ No newline at end of file