refactor genome.py use (C, 4) to replace (2, N, N) to represent connections

use "cons" to replace "connections" in code for beauty
This commit is contained in:
wls2002
2023-05-11 19:49:19 +08:00
parent e2a5117554
commit e5fc1167d9
3 changed files with 94 additions and 95 deletions

View File

@@ -5,8 +5,6 @@ import jax
from jax import jit, vmap, Array from jax import jit, vmap, Array
from jax import numpy as jnp from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
@jit @jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
@@ -29,11 +27,9 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections # crossover connections
cons1 = flatten_connections(keys1, connections1) con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
cons2 = flatten_connections(keys2, connections2) connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons) new_cons = unflatten_connections(len(keys1), new_cons)
return new_nodes, new_cons return new_nodes, new_cons
@@ -42,6 +38,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
@partial(jit, static_argnames=['gene_type']) @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
""" """
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1. make ar2 align with ar1.
:param seq1: :param seq1:
:param seq2: :param seq2:

View File

@@ -3,17 +3,15 @@ Vectorization of genome representation.
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where: Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes 1. N, C are pre-set values that determines the maximum number of nodes and connections in the network, and will increase if the genome becomes
too large to be represented by the current value of N. too large to be represented by the current value of N and C.
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function 2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
(act), and aggregation function (agg). (act), and aggregation function (agg).
3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled 3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled.
status.
Empty nodes or connections are represented using np.nan. Empty nodes or connections are represented using np.nan.
""" """
from typing import Tuple, Dict from typing import Tuple, Dict
from functools import partial
import jax import jax
import numpy as np import numpy as np
@@ -22,13 +20,12 @@ from jax import numpy as jnp
from jax import jit from jax import jit
from jax import Array from jax import Array
from .utils import fetch_first from .utils import fetch_first, EMPTY_NODE
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
def initialize_genomes(pop_size: int, def initialize_genomes(pop_size: int,
N: int, N: int,
C: int,
num_inputs: int, num_inputs: int,
num_outputs: int, num_outputs: int,
default_bias: float = 0.0, default_bias: float = 0.0,
@@ -43,6 +40,7 @@ def initialize_genomes(pop_size: int,
Args: Args:
pop_size (int): Number of genomes to initialize. pop_size (int): Number of genomes to initialize.
N (int): Maximum number of nodes in the network. N (int): Maximum number of nodes in the network.
C (int): Maximum number of connections in the network.
num_inputs (int): Number of input nodes. num_inputs (int): Number of input nodes.
num_outputs (int): Number of output nodes. num_outputs (int): Number of output nodes.
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0. default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
@@ -60,9 +58,11 @@ def initialize_genomes(pop_size: int,
# Reserve one row for potential mutation adding an extra node # Reserve one row for potential mutation adding an extra node
assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \ assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!" f"{num_inputs} and output_size: {num_outputs}!"
assert num_inputs * num_outputs + 1 <= C, f"Too small C: {C} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!"
pop_nodes = np.full((pop_size, N, 5), np.nan) pop_nodes = np.full((pop_size, N, 5), np.nan)
pop_connections = np.full((pop_size, 2, N, N), np.nan) pop_cons = np.full((pop_size, C, 4), np.nan)
input_idx = np.arange(num_inputs) input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs) output_idx = np.arange(num_inputs, num_inputs + num_outputs)
@@ -74,64 +74,69 @@ def initialize_genomes(pop_size: int,
pop_nodes[:, output_idx, 3] = default_act pop_nodes[:, output_idx, 3] = default_act
pop_nodes[:, output_idx, 4] = default_agg pop_nodes[:, output_idx, 4] = default_agg
for i in input_idx: grid_a, grid_b = np.meshgrid(input_idx, output_idx)
for j in output_idx: grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
pop_connections[:, 0, i, j] = default_weight
pop_connections[:, 1, i, j] = 1
return pop_nodes, pop_connections, input_idx, output_idx pop_cons[:, :num_inputs * num_outputs, 0] = grid_a
pop_cons[:, :num_inputs * num_outputs, 1] = grid_b
pop_cons[:, :num_inputs * num_outputs, 2] = default_weight
pop_cons[:, :num_inputs * num_outputs, 3] = 1
return pop_nodes, pop_cons, input_idx, output_idx
def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
""" """
Expand the genome to accommodate more nodes. Expand the genome to accommodate more nodes.
:param pop_nodes: (pop_size, N, 5) :param pop_nodes: (pop_size, N, 5)
:param pop_connections: (pop_size, 2, N, N) :param pop_cons: (pop_size, C, 4)
:param new_N: :param new_N:
:param new_C:
:return: :return:
""" """
pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1] pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan) new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
new_pop_nodes[:, :old_N, :] = pop_nodes new_pop_nodes[:, :old_N, :] = pop_nodes
new_pop_connections = np.full((pop_size, 2, new_N, new_N), np.nan) new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
new_pop_connections[:, :, :old_N, :old_N] = pop_connections new_pop_cons[:, :old_C, :] = pop_cons
return new_pop_nodes, new_pop_connections
return new_pop_nodes, new_pop_cons
def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
""" """
Expand a single genome to accommodate more nodes. Expand a single genome to accommodate more nodes.
:param nodes: (N, 5) :param nodes: (N, 5)
:param connections: (2, N, N) :param cons: (2, N, N)
:param new_N: :param new_N:
:param new_C:
:return: :return:
""" """
old_N = nodes.shape[0] old_N, old_C = nodes.shape[0], cons.shape[0]
new_nodes = np.full((new_N, 5), np.nan) new_nodes = np.full((new_N, 5), np.nan)
new_nodes[:old_N, :] = nodes new_nodes[:old_N, :] = nodes
new_connections = np.full((2, new_N, new_N), np.nan) new_cons = np.full((new_C, 4), np.nan)
new_connections[:, :old_N, :old_N] = connections new_cons[:old_C, :] = cons
return new_nodes, new_connections return new_nodes, new_cons
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \ def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]: Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
""" """
Convert a genome from array to dict. Convert a genome from array to dict.
:param nodes: (N, 5) :param nodes: (N, 5)
:param connections: (2, N, N) :param cons: (C, 4)
:param output_keys: :param output_keys:
:param input_keys: :param input_keys:
:return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)] :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
""" """
# update nodes_dict # update nodes_dict
try: try:
nodes_dict = {} nodes_dict = {}
idx2key = {}
for i, node in enumerate(nodes): for i, node in enumerate(nodes):
if np.isnan(node[0]): if np.isnan(node[0]):
continue continue
@@ -143,7 +148,6 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
act = node[3] if not np.isnan(node[3]) else None act = node[3] if not np.isnan(node[3]) else None
agg = node[4] if not np.isnan(node[4]) else None agg = node[4] if not np.isnan(node[4]) else None
nodes_dict[key] = (bias, response, act, agg) nodes_dict[key] = (bias, response, act, agg)
idx2key[i] = key
# check nodes_dict # check nodes_dict
for i in input_keys: for i in input_keys:
@@ -162,117 +166,109 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
f"Normal node {k} must has non-None bias, response, act, or agg!" f"Normal node {k} must has non-None bias, response, act, or agg!"
# update connections # update connections
connections_dict = {} cons_dict = {}
for i in range(connections.shape[1]): for i, con in enumerate(cons):
for j in range(connections.shape[2]): if np.isnan(con[0]):
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
continue continue
assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!" assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!"
assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!" i_key = int(con[0])
key = (idx2key[i], idx2key[j]) o_key = int(con[1])
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None key = (i_key, o_key)
weight = con[2] if not np.isnan(con[2]) else None
enabled = (con[3] == 1) if not np.isnan(con[3]) else None
assert weight is not None, f"Connection {key} must has non-None weight!" assert weight is not None, f"Connection {key} must has non-None weight!"
assert enabled is not None, f"Connection {key} must has non-None enabled!" assert enabled is not None, f"Connection {key} must has non-None enabled!"
connections_dict[key] = (weight, enabled)
return nodes_dict, connections_dict cons_dict[key] = (weight, enabled)
return nodes_dict, cons_dict
except AssertionError: except AssertionError:
print(nodes) print(nodes)
print(connections) print(cons)
raise AssertionError raise AssertionError
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys): def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys):
pop_nodes, pop_connections = jax.device_get((pop_nodes, pop_connections))
res = [] res = []
for nodes, connections in zip(pop_nodes, pop_connections): for nodes, cons in zip(pop_nodes, pop_cons):
res.append(analysis(nodes, connections, input_keys, output_keys)) res.append(analysis(nodes, cons, input_keys, output_keys))
return res return res
@jit @jit
def count(nodes, connections): def count(nodes, cons):
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0])) node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
connections_cnt = jnp.sum(~jnp.isnan(connections[0, :, :])) cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, connections_cnt return node_cnt, cons_cnt
@jit @jit
def add_node(new_node_key: int, nodes: Array, connections: Array, def add_node(nodes: Array, cons: Array, new_key: int,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
""" """
add a new node to the genome. add a new node to the genome.
""" """
exist_keys = nodes[:, 0] exist_keys = nodes[:, 0]
idx = fetch_first(jnp.isnan(exist_keys)) idx = fetch_first(jnp.isnan(exist_keys))
nodes = nodes.at[idx].set(jnp.array([new_node_key, bias, response, act, agg])) nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg]))
return nodes, connections return nodes, cons
@jit @jit
def delete_node(node_key: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]:
""" """
delete a node from the genome. only delete the node, regardless of connections. delete a node from the genome. only delete the node, regardless of connections.
""" """
node_keys = nodes[:, 0] node_keys = nodes[:, 0]
idx = fetch_first(node_keys == node_key) idx = fetch_first(node_keys == node_key)
return delete_node_by_idx(idx, nodes, connections) return delete_node_by_idx(nodes, cons, idx)
@jit @jit
def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
""" """
delete a node from the genome. only delete the node, regardless of connections. use idx to delete a node from the genome. only delete the node, regardless of connections.
""" """
# node_keys = nodes[:, 0]
nodes = nodes.at[idx].set(EMPTY_NODE) nodes = nodes.at[idx].set(EMPTY_NODE)
# move the last node to the deleted node's position return nodes, cons
# last_real_idx = fetch_last(~jnp.isnan(node_keys))
# nodes = nodes.at[idx].set(nodes[last_real_idx])
# nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
return nodes, connections
@jit @jit
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array, def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]: weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
""" """
add a new connection to the genome. add a new connection to the genome.
""" """
node_keys = nodes[:, 0] con_keys = cons[:, 0]
from_idx = fetch_first(node_keys == from_node) idx = fetch_first(jnp.isnan(con_keys))
to_idx = fetch_first(node_keys == to_node) return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled)
return add_connection_by_idx(from_idx, to_idx, nodes, connections, weight, enabled)
@jit @jit
def add_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array, def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]: weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
""" """
add a new connection to the genome. use idx to add a new connection to the genome.
""" """
connections = connections.at[:, from_idx, to_idx].set(jnp.array([weight, enabled])) cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled]))
return nodes, connections return nodes, cons
@jit @jit
def delete_connection(from_node: int, to_node: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]:
""" """
delete a connection from the genome. delete a connection from the genome.
""" """
node_keys = nodes[:, 0] idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
from_idx = fetch_first(node_keys == from_node) return delete_connection_by_idx(nodes, cons, idx)
to_idx = fetch_first(node_keys == to_node)
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
@jit @jit
def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
""" """
delete a connection from the genome. use idx to delete a connection from the genome.
""" """
connections = connections.at[:, from_idx, to_idx].set(np.nan) cons = cons.at[idx].set(np.nan)
return nodes, connections return nodes, cons

View File

@@ -1,5 +1,11 @@
import jax.numpy as jnp import numpy as np
EMPTY_NODE = jnp.full((1, 5), jnp.nan) # 输入
a = np.array([1, 2, 3, 4])
b = np.array([5, 6])
print(EMPTY_NODE) # 创建一个网格,其中包含所有可能的组合
aa, bb = np.meshgrid(a, b)
aa = aa.flatten()
bb = bb.flatten()
print(aa, bb)