refactor genome.py use (C, 4) to replace (2, N, N) to represent connections

use "cons" to replace "connections" in code for beauty
This commit is contained in:
wls2002
2023-05-11 19:49:19 +08:00
parent e2a5117554
commit e5fc1167d9
3 changed files with 94 additions and 95 deletions

View File

@@ -5,8 +5,6 @@ import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
@jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
@@ -29,11 +27,9 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons)
return new_nodes, new_cons
@@ -42,6 +38,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
@partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1.
:param seq1:
:param seq2:

View File

@@ -3,17 +3,15 @@ Vectorization of genome representation.
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes
too large to be represented by the current value of N.
1. N, C are pre-set values that determines the maximum number of nodes and connections in the network, and will increase if the genome becomes
too large to be represented by the current value of N and C.
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
(act), and aggregation function (agg).
3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled
status.
3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled.
Empty nodes or connections are represented using np.nan.
"""
from typing import Tuple, Dict
from functools import partial
import jax
import numpy as np
@@ -22,13 +20,12 @@ from jax import numpy as jnp
from jax import jit
from jax import Array
from .utils import fetch_first
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
from .utils import fetch_first, EMPTY_NODE
def initialize_genomes(pop_size: int,
N: int,
C: int,
num_inputs: int,
num_outputs: int,
default_bias: float = 0.0,
@@ -43,6 +40,7 @@ def initialize_genomes(pop_size: int,
Args:
pop_size (int): Number of genomes to initialize.
N (int): Maximum number of nodes in the network.
C (int): Maximum number of connections in the network.
num_inputs (int): Number of input nodes.
num_outputs (int): Number of output nodes.
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
@@ -60,9 +58,11 @@ def initialize_genomes(pop_size: int,
# Reserve one row for potential mutation adding an extra node
assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!"
assert num_inputs * num_outputs + 1 <= C, f"Too small C: {C} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!"
pop_nodes = np.full((pop_size, N, 5), np.nan)
pop_connections = np.full((pop_size, 2, N, N), np.nan)
pop_cons = np.full((pop_size, C, 4), np.nan)
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
@@ -74,64 +74,69 @@ def initialize_genomes(pop_size: int,
pop_nodes[:, output_idx, 3] = default_act
pop_nodes[:, output_idx, 4] = default_agg
for i in input_idx:
for j in output_idx:
pop_connections[:, 0, i, j] = default_weight
pop_connections[:, 1, i, j] = 1
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
return pop_nodes, pop_connections, input_idx, output_idx
pop_cons[:, :num_inputs * num_outputs, 0] = grid_a
pop_cons[:, :num_inputs * num_outputs, 1] = grid_b
pop_cons[:, :num_inputs * num_outputs, 2] = default_weight
pop_cons[:, :num_inputs * num_outputs, 3] = 1
return pop_nodes, pop_cons, input_idx, output_idx
def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand the genome to accommodate more nodes.
:param pop_nodes: (pop_size, N, 5)
:param pop_connections: (pop_size, 2, N, N)
:param pop_cons: (pop_size, C, 4)
:param new_N:
:param new_C:
:return:
"""
pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1]
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
new_pop_nodes[:, :old_N, :] = pop_nodes
new_pop_connections = np.full((pop_size, 2, new_N, new_N), np.nan)
new_pop_connections[:, :, :old_N, :old_N] = pop_connections
return new_pop_nodes, new_pop_connections
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
new_pop_cons[:, :old_C, :] = pop_cons
return new_pop_nodes, new_pop_cons
def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand a single genome to accommodate more nodes.
:param nodes: (N, 5)
:param connections: (2, N, N)
:param cons: (2, N, N)
:param new_N:
:param new_C:
:return:
"""
old_N = nodes.shape[0]
old_N, old_C = nodes.shape[0], cons.shape[0]
new_nodes = np.full((new_N, 5), np.nan)
new_nodes[:old_N, :] = nodes
new_connections = np.full((2, new_N, new_N), np.nan)
new_connections[:, :old_N, :old_N] = connections
new_cons = np.full((new_C, 4), np.nan)
new_cons[:old_C, :] = cons
return new_nodes, new_connections
return new_nodes, new_cons
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
"""
Convert a genome from array to dict.
:param nodes: (N, 5)
:param connections: (2, N, N)
:param cons: (C, 4)
:param output_keys:
:param input_keys:
:return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)]
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
"""
# update nodes_dict
try:
nodes_dict = {}
idx2key = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
@@ -143,7 +148,6 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
act = node[3] if not np.isnan(node[3]) else None
agg = node[4] if not np.isnan(node[4]) else None
nodes_dict[key] = (bias, response, act, agg)
idx2key[i] = key
# check nodes_dict
for i in input_keys:
@@ -162,117 +166,109 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
f"Normal node {k} must has non-None bias, response, act, or agg!"
# update connections
connections_dict = {}
for i in range(connections.shape[1]):
for j in range(connections.shape[2]):
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
continue
assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!"
assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!"
key = (idx2key[i], idx2key[j])
cons_dict = {}
for i, con in enumerate(cons):
if np.isnan(con[0]):
continue
assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!"
i_key = int(con[0])
o_key = int(con[1])
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
key = (i_key, o_key)
weight = con[2] if not np.isnan(con[2]) else None
enabled = (con[3] == 1) if not np.isnan(con[3]) else None
assert weight is not None, f"Connection {key} must has non-None weight!"
assert enabled is not None, f"Connection {key} must has non-None enabled!"
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
cons_dict[key] = (weight, enabled)
assert weight is not None, f"Connection {key} must has non-None weight!"
assert enabled is not None, f"Connection {key} must has non-None enabled!"
connections_dict[key] = (weight, enabled)
return nodes_dict, connections_dict
return nodes_dict, cons_dict
except AssertionError:
print(nodes)
print(connections)
print(cons)
raise AssertionError
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
pop_nodes, pop_connections = jax.device_get((pop_nodes, pop_connections))
def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys):
res = []
for nodes, connections in zip(pop_nodes, pop_connections):
res.append(analysis(nodes, connections, input_keys, output_keys))
for nodes, cons in zip(pop_nodes, pop_cons):
res.append(analysis(nodes, cons, input_keys, output_keys))
return res
@jit
def count(nodes, connections):
def count(nodes, cons):
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
connections_cnt = jnp.sum(~jnp.isnan(connections[0, :, :]))
return node_cnt, connections_cnt
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, cons_cnt
@jit
def add_node(new_node_key: int, nodes: Array, connections: Array,
def add_node(nodes: Array, cons: Array, new_key: int,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
"""
add a new node to the genome.
"""
exist_keys = nodes[:, 0]
idx = fetch_first(jnp.isnan(exist_keys))
nodes = nodes.at[idx].set(jnp.array([new_node_key, bias, response, act, agg]))
return nodes, connections
nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg]))
return nodes, cons
@jit
def delete_node(node_key: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]:
"""
delete a node from the genome. only delete the node, regardless of connections.
"""
node_keys = nodes[:, 0]
idx = fetch_first(node_keys == node_key)
return delete_node_by_idx(idx, nodes, connections)
return delete_node_by_idx(nodes, cons, idx)
@jit
def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
"""
delete a node from the genome. only delete the node, regardless of connections.
use idx to delete a node from the genome. only delete the node, regardless of connections.
"""
# node_keys = nodes[:, 0]
nodes = nodes.at[idx].set(EMPTY_NODE)
# move the last node to the deleted node's position
# last_real_idx = fetch_last(~jnp.isnan(node_keys))
# nodes = nodes.at[idx].set(nodes[last_real_idx])
# nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
return nodes, connections
return nodes, cons
@jit
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array,
def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
add a new connection to the genome.
"""
node_keys = nodes[:, 0]
from_idx = fetch_first(node_keys == from_node)
to_idx = fetch_first(node_keys == to_node)
return add_connection_by_idx(from_idx, to_idx, nodes, connections, weight, enabled)
con_keys = cons[:, 0]
idx = fetch_first(jnp.isnan(con_keys))
return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled)
@jit
def add_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array,
def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
add a new connection to the genome.
use idx to add a new connection to the genome.
"""
connections = connections.at[:, from_idx, to_idx].set(jnp.array([weight, enabled]))
return nodes, connections
cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled]))
return nodes, cons
@jit
def delete_connection(from_node: int, to_node: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]:
"""
delete a connection from the genome.
"""
node_keys = nodes[:, 0]
from_idx = fetch_first(node_keys == from_node)
to_idx = fetch_first(node_keys == to_node)
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
return delete_connection_by_idx(nodes, cons, idx)
@jit
def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
"""
delete a connection from the genome.
use idx to delete a connection from the genome.
"""
connections = connections.at[:, from_idx, to_idx].set(np.nan)
return nodes, connections
cons = cons.at[idx].set(np.nan)
return nodes, cons