refactor genome.py use (C, 4) to replace (2, N, N) to represent connections
rename "connections" to "cons" throughout the code for brevity
This commit is contained in:
@@ -5,8 +5,6 @@ import jax
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||
@@ -29,11 +27,9 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
|
||||
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
new_cons = unflatten_connections(len(keys1), new_cons)
|
||||
|
||||
return new_nodes, new_cons
|
||||
@@ -42,6 +38,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
@partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
NOTE: the alignment logic below is order-sensitive; review carefully before modifying.
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
|
||||
@@ -3,17 +3,15 @@ Vectorization of genome representation.
|
||||
|
||||
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
|
||||
|
||||
1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes
|
||||
too large to be represented by the current value of N.
|
||||
1. N, C are pre-set values that determine the maximum number of nodes and connections in the network, and will increase if the genome becomes
|
||||
too large to be represented by the current values of N and C.
|
||||
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
|
||||
(act), and aggregation function (agg).
|
||||
3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled
|
||||
status.
|
||||
3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled.
|
||||
Empty nodes or connections are represented using np.nan.
|
||||
|
||||
"""
|
||||
from typing import Tuple, Dict
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
@@ -22,13 +20,12 @@ from jax import numpy as jnp
|
||||
from jax import jit
|
||||
from jax import Array
|
||||
|
||||
from .utils import fetch_first
|
||||
|
||||
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
|
||||
from .utils import fetch_first, EMPTY_NODE
|
||||
|
||||
|
||||
def initialize_genomes(pop_size: int,
|
||||
N: int,
|
||||
C: int,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
default_bias: float = 0.0,
|
||||
@@ -43,6 +40,7 @@ def initialize_genomes(pop_size: int,
|
||||
Args:
|
||||
pop_size (int): Number of genomes to initialize.
|
||||
N (int): Maximum number of nodes in the network.
|
||||
C (int): Maximum number of connections in the network.
|
||||
num_inputs (int): Number of input nodes.
|
||||
num_outputs (int): Number of output nodes.
|
||||
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
|
||||
@@ -60,9 +58,11 @@ def initialize_genomes(pop_size: int,
|
||||
# Reserve one row for potential mutation adding an extra node
|
||||
assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
|
||||
f"{num_inputs} and output_size: {num_outputs}!"
|
||||
assert num_inputs * num_outputs + 1 <= C, f"Too small C: {C} for input_size: " \
|
||||
f"{num_inputs} and output_size: {num_outputs}!"
|
||||
|
||||
pop_nodes = np.full((pop_size, N, 5), np.nan)
|
||||
pop_connections = np.full((pop_size, 2, N, N), np.nan)
|
||||
pop_cons = np.full((pop_size, C, 4), np.nan)
|
||||
input_idx = np.arange(num_inputs)
|
||||
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||
|
||||
@@ -74,64 +74,69 @@ def initialize_genomes(pop_size: int,
|
||||
pop_nodes[:, output_idx, 3] = default_act
|
||||
pop_nodes[:, output_idx, 4] = default_agg
|
||||
|
||||
for i in input_idx:
|
||||
for j in output_idx:
|
||||
pop_connections[:, 0, i, j] = default_weight
|
||||
pop_connections[:, 1, i, j] = 1
|
||||
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
|
||||
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
|
||||
|
||||
return pop_nodes, pop_connections, input_idx, output_idx
|
||||
pop_cons[:, :num_inputs * num_outputs, 0] = grid_a
|
||||
pop_cons[:, :num_inputs * num_outputs, 1] = grid_b
|
||||
pop_cons[:, :num_inputs * num_outputs, 2] = default_weight
|
||||
pop_cons[:, :num_inputs * num_outputs, 3] = 1
|
||||
|
||||
return pop_nodes, pop_cons, input_idx, output_idx
|
||||
|
||||
|
||||
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Expand the population genomes to accommodate more nodes and connections.

    Existing data is copied into the leading rows; all newly created slots
    are filled with np.nan, which marks empty nodes/connections.

    :param pop_nodes: (pop_size, N, 5) population node arrays
    :param pop_cons: (pop_size, C, 4) population connection arrays
    :param new_N: new maximum node count (assumed >= current N)
    :param new_C: new maximum connection count (assumed >= current C)
    :return: (new_pop_nodes, new_pop_cons) with shapes
             (pop_size, new_N, 5) and (pop_size, new_C, 4)
    """
    pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]

    new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
    new_pop_nodes[:, :old_N, :] = pop_nodes

    new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
    new_pop_cons[:, :old_C, :] = pop_cons

    return new_pop_nodes, new_pop_cons
|
||||
|
||||
|
||||
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Expand a single genome to accommodate more nodes and connections.

    Existing data is copied into the leading rows; all newly created slots
    are filled with np.nan, which marks empty nodes/connections.

    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :param new_N: new maximum node count (assumed >= current N)
    :param new_C: new maximum connection count (assumed >= current C)
    :return: (new_nodes, new_cons) with shapes (new_N, 5) and (new_C, 4)
    """
    old_N, old_C = nodes.shape[0], cons.shape[0]

    new_nodes = np.full((new_N, 5), np.nan)
    new_nodes[:old_N, :] = nodes

    new_cons = np.full((new_C, 4), np.nan)
    new_cons[:old_C, :] = cons

    return new_nodes, new_cons
|
||||
|
||||
|
||||
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
|
||||
def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \
|
||||
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
|
||||
"""
|
||||
Convert a genome from array to dict.
|
||||
:param nodes: (N, 5)
|
||||
:param connections: (2, N, N)
|
||||
:param cons: (C, 4)
|
||||
:param output_keys:
|
||||
:param input_keys:
|
||||
:return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)]
|
||||
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
|
||||
"""
|
||||
# update nodes_dict
|
||||
try:
|
||||
nodes_dict = {}
|
||||
idx2key = {}
|
||||
for i, node in enumerate(nodes):
|
||||
if np.isnan(node[0]):
|
||||
continue
|
||||
@@ -143,7 +148,6 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
|
||||
act = node[3] if not np.isnan(node[3]) else None
|
||||
agg = node[4] if not np.isnan(node[4]) else None
|
||||
nodes_dict[key] = (bias, response, act, agg)
|
||||
idx2key[i] = key
|
||||
|
||||
# check nodes_dict
|
||||
for i in input_keys:
|
||||
@@ -162,117 +166,109 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
|
||||
f"Normal node {k} must has non-None bias, response, act, or agg!"
|
||||
|
||||
# update connections
|
||||
connections_dict = {}
|
||||
for i in range(connections.shape[1]):
|
||||
for j in range(connections.shape[2]):
|
||||
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
|
||||
continue
|
||||
assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!"
|
||||
assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!"
|
||||
key = (idx2key[i], idx2key[j])
|
||||
cons_dict = {}
|
||||
for i, con in enumerate(cons):
|
||||
if np.isnan(con[0]):
|
||||
continue
|
||||
assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!"
|
||||
i_key = int(con[0])
|
||||
o_key = int(con[1])
|
||||
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
|
||||
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
|
||||
key = (i_key, o_key)
|
||||
weight = con[2] if not np.isnan(con[2]) else None
|
||||
enabled = (con[3] == 1) if not np.isnan(con[3]) else None
|
||||
assert weight is not None, f"Connection {key} must has non-None weight!"
|
||||
assert enabled is not None, f"Connection {key} must has non-None enabled!"
|
||||
|
||||
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
|
||||
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
|
||||
cons_dict[key] = (weight, enabled)
|
||||
|
||||
assert weight is not None, f"Connection {key} must has non-None weight!"
|
||||
assert enabled is not None, f"Connection {key} must has non-None enabled!"
|
||||
connections_dict[key] = (weight, enabled)
|
||||
|
||||
return nodes_dict, connections_dict
|
||||
return nodes_dict, cons_dict
|
||||
except AssertionError:
|
||||
print(nodes)
|
||||
print(connections)
|
||||
print(cons)
|
||||
raise AssertionError
|
||||
|
||||
|
||||
def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys):
    """
    Convert every genome in the population from array form to dict form.

    :param pop_nodes: (pop_size, N, 5) population node arrays
    :param pop_cons: (pop_size, C, 4) population connection arrays
    :param input_keys: keys of input nodes
    :param output_keys: keys of output nodes
    :return: list of (nodes_dict, cons_dict) tuples, one per genome
             (see `analysis` for the dict formats)
    """
    # NOTE(review): the refactor dropped the jax.device_get call that used to
    # pull arrays to host memory here — confirm callers pass host arrays.
    return [analysis(nodes, cons, input_keys, output_keys)
            for nodes, cons in zip(pop_nodes, pop_cons)]
|
||||
|
||||
|
||||
@jit
def count(nodes, cons):
    """
    Count the occupied node and connection slots in a genome.

    A slot is considered occupied when its key column (column 0) is not nan.

    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :return: (node_cnt, cons_cnt) as scalar jax arrays
    """
    node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
    cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
    return node_cnt, cons_cnt
|
||||
|
||||
|
||||
@jit
def add_node(nodes: Array, cons: Array, new_key: int,
             bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
    """
    Add a new node to the genome.

    The node is written into the first empty row of `nodes` (a row whose
    key column is nan). Connections are returned unchanged.

    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array (passed through untouched)
    :param new_key: key of the node to add
    :param bias: bias of the new node
    :param response: response of the new node
    :param act: activation function index
    :param agg: aggregation function index
    :return: (nodes, cons) with the new node inserted
    """
    exist_keys = nodes[:, 0]
    idx = fetch_first(jnp.isnan(exist_keys))
    nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg]))
    return nodes, cons
|
||||
|
||||
|
||||
@jit
def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]:
    """
    Delete a node from the genome by its key.

    Only the node row is cleared; connections referencing it are NOT removed.

    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array (passed through untouched)
    :param node_key: key of the node to delete
    :return: (nodes, cons) with the node row emptied
    """
    node_keys = nodes[:, 0]
    idx = fetch_first(node_keys == node_key)
    return delete_node_by_idx(nodes, cons, idx)
|
||||
|
||||
|
||||
@jit
def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
    """
    Delete a node from the genome by row index.

    The row at `idx` is overwritten with EMPTY_NODE (all-nan sentinel).
    Only the node is deleted; connections referencing it are NOT removed.

    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array (passed through untouched)
    :param idx: row index of the node to delete
    :return: (nodes, cons) with the node row emptied
    """
    nodes = nodes.at[idx].set(EMPTY_NODE)
    return nodes, cons
|
||||
|
||||
|
||||
@jit
def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
                   weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Add a new connection to the genome.

    Finds the first empty connection slot (i_key column is nan) and writes
    the new connection there via add_connection_by_idx.

    :param nodes: (N, 5) node array (passed through untouched)
    :param cons: (C, 4) connection array
    :param i_key: key of the input (source) node
    :param o_key: key of the output (target) node
    :param weight: connection weight
    :param enabled: whether the connection starts enabled
    :return: (nodes, cons) with the new connection inserted
    """
    con_keys = cons[:, 0]
    idx = fetch_first(jnp.isnan(con_keys))
    # Argument order must match add_connection_by_idx(nodes, cons, idx, ...);
    # passing idx first would silently misroute the arrays.
    return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled)
|
||||
|
||||
|
||||
@jit
def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int,
                          weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Add a new connection to the genome at a known slot index.

    Writes the row [i_key, o_key, weight, enabled] into cons[idx].
    Nodes are returned unchanged.

    :param nodes: (N, 5) node array (passed through untouched)
    :param cons: (C, 4) connection array
    :param idx: slot index to write the connection into
    :param i_key: key of the input (source) node
    :param o_key: key of the output (target) node
    :param weight: connection weight (default 0.0 here, unlike add_connection's 1.0)
    :param enabled: whether the connection is enabled (stored as 1.0/0.0)
    :return: (nodes, cons) with the connection written
    """
    cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled]))
    return nodes, cons
|
||||
|
||||
|
||||
@jit
def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]:
    """
    Delete a connection from the genome by its (i_key, o_key) pair.

    Locates the matching connection row and clears it via
    delete_connection_by_idx. Nodes are returned unchanged.

    :param nodes: (N, 5) node array (passed through untouched)
    :param cons: (C, 4) connection array
    :param i_key: key of the input (source) node
    :param o_key: key of the output (target) node
    :return: (nodes, cons) with the connection row emptied
    """
    idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
    return delete_connection_by_idx(nodes, cons, idx)
|
||||
|
||||
|
||||
@jit
|
||||
def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
|
||||
def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
|
||||
"""
|
||||
delete a connection from the genome.
|
||||
use idx to delete a connection from the genome.
|
||||
"""
|
||||
connections = connections.at[:, from_idx, to_idx].set(np.nan)
|
||||
return nodes, connections
|
||||
cons = cons.at[idx].set(np.nan)
|
||||
return nodes, cons
|
||||
|
||||
Reference in New Issue
Block a user