202 lines
7.0 KiB
Python
202 lines
7.0 KiB
Python
"""
|
|
Vectorization of genome representation.
|
|
|
|
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
|
|
|
|
1. N, C are pre-set values that determines the maximum number of nodes and connections in the network, and will increase if the genome becomes
|
|
too large to be represented by the current value of N and C.
|
|
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
|
|
(act), and aggregation function (agg).
|
|
3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled.
|
|
Empty nodes or connections are represented using np.nan.
|
|
|
|
"""
|
|
from typing import Tuple, Dict
|
|
|
|
import jax
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
from jax import numpy as jnp
|
|
from jax import jit
|
|
from jax import Array
|
|
|
|
from .utils import fetch_first
|
|
|
|
|
|
def initialize_genomes(pop_size: int,
                       N: int,
                       C: int,
                       num_inputs: int,
                       num_outputs: int,
                       default_bias: float = 0.0,
                       default_response: float = 1.0,
                       default_act: int = 0,
                       default_agg: int = 0,
                       default_weight: float = 0.0) \
        -> Tuple[NDArray, NDArray, NDArray, NDArray]:
    """
    Initialize genomes with default values.

    Every genome starts fully connected: each input node has one enabled connection to each output node.
    Input nodes only get their key set; their remaining attribute columns stay NaN.

    Args:
        pop_size (int): Number of genomes to initialize.
        N (int): Maximum number of nodes in the network.
        C (int): Maximum number of connections in the network.
        num_inputs (int): Number of input nodes.
        num_outputs (int): Number of output nodes.
        default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
        default_response (float, optional): Default response value for output nodes. Defaults to 1.0.
        default_act (int, optional): Default activation function index for output nodes. Defaults to 0.
        default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0.
        default_weight (float, optional): Default weight value for connections. Defaults to 0.0.

    Raises:
        AssertionError: If num_inputs + num_outputs + 1 is greater than N, or if
            num_inputs * num_outputs + 1 is greater than C.

    Returns:
        Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes of shape (pop_size, N, 5), pop_connections of shape
            (pop_size, C, 4), input_idx, and output_idx arrays.
    """
    num_init_cons = num_inputs * num_outputs

    # Reserve one spare row in each array for a potential mutation adding an extra node / connection.
    # The checks are raised explicitly (rather than written as `assert` statements) so that they are
    # not stripped when Python runs with -O; the exception type and messages are unchanged.
    if num_inputs + num_outputs + 1 > N:
        raise AssertionError(f"Too small N: {N} for input_size: "
                             f"{num_inputs} and output_size: {num_outputs}!")
    if num_init_cons + 1 > C:
        raise AssertionError(f"Too small C: {C} for input_size: "
                             f"{num_inputs} and output_size: {num_outputs}!")

    pop_nodes = np.full((pop_size, N, 5), np.nan)
    pop_cons = np.full((pop_size, C, 4), np.nan)
    input_idx = np.arange(num_inputs)
    output_idx = np.arange(num_inputs, num_inputs + num_outputs)

    # Node keys coincide with their row index for the initial input and output nodes.
    pop_nodes[:, input_idx, 0] = input_idx
    pop_nodes[:, output_idx, 0] = output_idx

    # Only output nodes receive bias / response / act / agg defaults.
    pop_nodes[:, output_idx, 1] = default_bias
    pop_nodes[:, output_idx, 2] = default_response
    pop_nodes[:, output_idx, 3] = default_act
    pop_nodes[:, output_idx, 4] = default_agg

    # Enumerate every (input, output) pair; the input key varies fastest.
    grid_a, grid_b = np.meshgrid(input_idx, output_idx)
    grid_a, grid_b = grid_a.flatten(), grid_b.flatten()

    pop_cons[:, :num_init_cons, 0] = grid_a
    pop_cons[:, :num_init_cons, 1] = grid_b
    pop_cons[:, :num_init_cons, 2] = default_weight
    pop_cons[:, :num_init_cons, 3] = 1  # enabled

    return pop_nodes, pop_cons, input_idx, output_idx
|
|
|
|
|
|
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow the genome arrays of a whole population to new capacities.

    Fresh NaN-filled arrays are allocated and the existing data is copied into their leading rows,
    so the added rows are empty slots.

    :param pop_nodes: (pop_size, N, 5)
    :param pop_cons: (pop_size, C, 4)
    :param new_N: node capacity of the returned node array
    :param new_C: connection capacity of the returned connection array
    :return: (new_pop_nodes, new_pop_cons) with shapes (pop_size, new_N, 5) and (pop_size, new_C, 4)
    """
    n_genomes, cur_n = pop_nodes.shape[:2]
    cur_c = pop_cons.shape[1]

    grown_nodes = np.full(shape=(n_genomes, new_N, 5), fill_value=np.nan)
    grown_nodes[:, :cur_n, :] = pop_nodes

    grown_cons = np.full(shape=(n_genomes, new_C, 4), fill_value=np.nan)
    grown_cons[:, :cur_c, :] = pop_cons

    return grown_nodes, grown_cons
|
|
|
|
|
|
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow the arrays of a single genome to new capacities.

    Fresh NaN-filled arrays are allocated and the existing rows are copied to the front,
    so the added rows are empty slots.

    :param nodes: (N, 5)
    :param cons: (C, 4)
    :param new_N: node capacity of the returned node array
    :param new_C: connection capacity of the returned connection array
    :return: (new_nodes, new_cons) with shapes (new_N, 5) and (new_C, 4)
    """
    cur_n = nodes.shape[0]
    cur_c = cons.shape[0]

    grown_nodes = np.full(shape=(new_N, 5), fill_value=np.nan)
    grown_nodes[:cur_n, :] = nodes

    grown_cons = np.full(shape=(new_C, 4), fill_value=np.nan)
    grown_cons[:cur_c, :] = cons

    return grown_nodes, grown_cons
|
|
|
|
|
|
@jit
def count(nodes, cons):
    """
    Count the occupied rows of a single genome.

    A row counts as occupied when its first column (the node key, or the connection's i_key) is not NaN.

    :param nodes: node array, indexed as nodes[:, 0]
    :param cons: connection array, indexed as cons[:, 0]
    :return: (node_cnt, cons_cnt)
    """
    node_used = jnp.logical_not(jnp.isnan(nodes[:, 0]))
    cons_used = jnp.logical_not(jnp.isnan(cons[:, 0]))
    return jnp.sum(node_used), jnp.sum(cons_used)
|
|
|
|
|
|
@jit
def add_node(nodes: Array, cons: Array, new_key: int,
             bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
    """
    Add a new node to the genome.

    The node is written as the row [new_key, bias, response, act, agg] at the position returned by
    fetch_first on the "key is NaN" mask. The connection array is returned untouched.
    """
    # NOTE(review): fetch_first presumably yields the index of the first True entry,
    # i.e. the first empty node slot — confirm against utils.fetch_first.
    free_row = fetch_first(jnp.isnan(nodes[:, 0]))
    new_row = jnp.array([new_key, bias, response, act, agg])
    return nodes.at[free_row].set(new_row), cons
|
|
|
|
|
|
@jit
def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]:
    """
    Delete the node with key `node_key` from the genome.

    Only the node row is cleared; connections that reference the node are left as they are.
    """
    # Locate the row by comparing the key column against the requested key.
    target_row = fetch_first(nodes[:, 0] == node_key)
    return delete_node_by_idx(nodes, cons, target_row)
|
|
|
|
|
|
@jit
def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
    """
    Delete the node stored at row `idx`.

    The whole row is overwritten with NaN, which marks it as empty. Only the node is removed;
    the connection array is returned unchanged.
    """
    cleared_nodes = nodes.at[idx].set(jnp.nan)
    return cleared_nodes, cons
|
|
|
|
|
|
@jit
def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
                   weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Add a new connection i_key -> o_key to the genome.

    The target row is whatever fetch_first returns for the "i_key column is NaN" mask; the actual
    write is delegated to add_connection_by_idx.

    Note that `weight` defaults to 1.0 here, whereas add_connection_by_idx defaults it to 0.0.
    """
    # NOTE(review): fetch_first presumably yields the first empty connection slot —
    # confirm against utils.fetch_first.
    free_row = fetch_first(jnp.isnan(cons[:, 0]))
    return add_connection_by_idx(nodes, cons, free_row, i_key, o_key, weight, enabled)
|
|
|
|
|
|
@jit
def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int,
                          weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Write a connection into row `idx` of the connection array.

    The row is set to [i_key, o_key, weight, enabled]; the node array is returned unchanged.
    """
    new_row = jnp.array([i_key, o_key, weight, enabled])
    return nodes, cons.at[idx].set(new_row)
|
|
|
|
|
|
@jit
def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]:
    """
    Delete the connection i_key -> o_key from the genome.

    The row is located by matching both key columns, then cleared via delete_connection_by_idx.
    """
    same_input = cons[:, 0] == i_key
    same_output = cons[:, 1] == o_key
    target_row = fetch_first(jnp.logical_and(same_input, same_output))
    return delete_connection_by_idx(nodes, cons, target_row)
|
|
|
|
|
|
@jit
def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
    """
    Delete the connection stored at row `idx`.

    The whole row is overwritten with NaN, which marks it as empty; the node array is returned unchanged.
    """
    cleared_cons = cons.at[idx].set(jnp.nan)
    return nodes, cleared_cons
|