112 lines
3.8 KiB
Python
112 lines
3.8 KiB
Python
from typing import Type, Tuple
|
|
|
|
import numpy as np
|
|
import jax
|
|
from jax import Array, numpy as jnp
|
|
|
|
from algorithm import State
|
|
from ..gene import BaseGene
|
|
from algorithm.utils import fetch_first
|
|
|
|
|
|
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
|
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
|
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
|
|
|
|
input_idx = state.input_idx
|
|
output_idx = state.output_idx
|
|
new_node_key = max([*input_idx, *output_idx]) + 1
|
|
|
|
o_nodes[input_idx, 0] = input_idx
|
|
o_nodes[output_idx, 0] = output_idx
|
|
o_nodes[new_node_key, 0] = new_node_key
|
|
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = jax.device_get(gene_type.new_node_attrs(state))
|
|
o_nodes[new_node_key, 1:] = jax.device_get(gene_type.new_node_attrs(state))
|
|
|
|
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
|
|
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
|
o_conns[input_idx, 2] = True # enabled
|
|
o_conns[input_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state))
|
|
|
|
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
|
|
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
|
o_conns[output_idx, 2] = True # enabled
|
|
o_conns[output_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state))
|
|
|
|
# repeat origin genome for P times to create population
|
|
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
|
|
pop_conns = np.tile(o_conns, (state.P, 1, 1))
|
|
|
|
return jax.device_put([pop_nodes, pop_conns])
|
|
|
|
|
|
def count(nodes: Array, cons: Array):
|
|
"""
|
|
Count how many nodes and connections are in the genome.
|
|
"""
|
|
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
|
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
|
return node_cnt, cons_cnt
|
|
|
|
|
|
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
|
|
"""
|
|
Add a new node to the genome.
|
|
The new node will place at the first NaN row.
|
|
"""
|
|
exist_keys = nodes[:, 0]
|
|
idx = fetch_first(jnp.isnan(exist_keys))
|
|
nodes = nodes.at[idx, 0].set(new_key)
|
|
nodes = nodes.at[idx, 1:].set(attrs)
|
|
return nodes, cons
|
|
|
|
|
|
def delete_node(nodes: Array, cons: Array, node_key: Array) -> Tuple[Array, Array]:
|
|
"""
|
|
Delete a node from the genome. Only delete the node, regardless of connections.
|
|
Delete the node by its key.
|
|
"""
|
|
node_keys = nodes[:, 0]
|
|
idx = fetch_first(node_keys == node_key)
|
|
return delete_node_by_idx(nodes, cons, idx)
|
|
|
|
|
|
def delete_node_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]:
|
|
"""
|
|
Delete a node from the genome. Only delete the node, regardless of connections.
|
|
Delete the node by its idx.
|
|
"""
|
|
nodes = nodes.at[idx].set(np.nan)
|
|
return nodes, cons
|
|
|
|
|
|
def add_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array, enable: bool, attrs: Array) -> Tuple[
|
|
Array, Array]:
|
|
"""
|
|
Add a new connection to the genome.
|
|
The new connection will place at the first NaN row.
|
|
"""
|
|
con_keys = cons[:, 0]
|
|
idx = fetch_first(jnp.isnan(con_keys))
|
|
cons = cons.at[idx, 0:3].set(jnp.array([i_key, o_key, enable]))
|
|
cons = cons.at[idx, 3:].set(attrs)
|
|
return nodes, cons
|
|
|
|
|
|
def delete_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array) -> Tuple[Array, Array]:
|
|
"""
|
|
Delete a connection from the genome.
|
|
Delete the connection by its input and output node keys.
|
|
"""
|
|
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
|
return delete_connection_by_idx(nodes, cons, idx)
|
|
|
|
|
|
def delete_connection_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]:
|
|
"""
|
|
Delete a connection from the genome.
|
|
Delete the connection by its idx.
|
|
"""
|
|
cons = cons.at[idx].set(np.nan)
|
|
return nodes, cons
|