91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
from typing import Tuple
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from .utils import flatten_connections, unflatten_connections
|
|
|
|
|
|
def batch_crossover(batch_nodes1: NDArray, batch_connections1: NDArray, batch_nodes2: NDArray,
                    batch_connections2: NDArray) -> Tuple[NDArray, NDArray]:
    """
    crossover a batch of genome pairs
    the four inputs are iterated together along their first axis, and the i-th entries of each
    are handed to ``crossover`` as one pair of parents (first genome = winner)
    :param batch_nodes1: nodes of the winner genomes, one genome per entry along axis 0
    :param batch_connections1: connections of the winner genomes, one genome per entry along axis 0
    :param batch_nodes2: nodes of the loser genomes, one genome per entry along axis 0
    :param batch_connections2: connections of the loser genomes, one genome per entry along axis 0
    :return: (nodes, connections) of the children, each stacked along a new leading axis
    """
    parent_pairs = zip(batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
    children = [crossover(*parents) for parents in parent_pairs]

    # every child is a (nodes, connections) pair; stack each component separately
    stacked_nodes = np.stack([child[0] for child in children], axis=0)
    stacked_cons = np.stack([child[1] for child in children], axis=0)
    return stacked_nodes, stacked_cons
|
|
|
|
|
|
def crossover(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) \
        -> Tuple[NDArray, NDArray]:
    """
    build a child genome from two parent genomes
    genome1 is expected to be the fitter parent (the winner): wherever a value is NaN in
    genome1 or in the aligned genome2, the child takes the value of genome1; everything else
    is mixed by ``crossover_gene``
    :param nodes1: nodes of the winner; column 0 holds the node keys
    :param connections1: connections of the winner, in the layout ``flatten_connections`` accepts
    :param nodes2: nodes of the loser; column 0 holds the node keys
    :param connections2: connections of the loser, in the layout ``flatten_connections`` accepts
    :return: (new_nodes, new_connections); new_nodes is row-aligned with nodes1, new_connections
             is whatever ``unflatten_connections`` builds from the mixed connection rows
    """
    node_keys1 = nodes1[:, 0]
    node_keys2 = nodes2[:, 0]

    # nodes: reorder genome2 so that equal keys share a row with genome1, then mix
    aligned_nodes2 = align_array(node_keys1, node_keys2, nodes2, 'node')
    node_missing = np.isnan(nodes1) | np.isnan(aligned_nodes2)
    child_nodes = np.where(node_missing, nodes1, crossover_gene(nodes1, aligned_nodes2))

    # connections: same procedure on the flattened form, whose first two columns act as the key
    flat_cons1 = flatten_connections(node_keys1, connections1)
    flat_cons2 = flatten_connections(node_keys2, connections2)
    aligned_cons2 = align_array(flat_cons1[:, :2], flat_cons2[:, :2], flat_cons2, 'connection')
    con_missing = np.isnan(flat_cons1) | np.isnan(aligned_cons2)
    child_cons = np.where(con_missing, flat_cons1, crossover_gene(flat_cons1, aligned_cons2))

    return child_nodes, unflatten_connections(len(node_keys1), child_cons)
|
|
|
|
|
|
def align_array(seq1: NDArray, seq2: NDArray, ar2: NDArray, gene_type: str) -> NDArray:
    """
    make ar2 align with ar1.
    row i of the result is the row of ar2 whose key equals seq1[i]; rows whose key has no
    match in seq2 (NaN keys never match) are filled with NaN.
    seq1 and seq2 may have different lengths. each key is assumed to occur at most once
    in seq2, because the matching row index is recovered by summing the matched positions.
    :param seq1: keys of the reference array; 1-D for 'node', 2-D (one key row per gene)
                 for 'connection'
    :param seq2: keys of ar2 in the same layout as seq1; seq2[j] is the key of ar2[j]
    :param ar2: 2-D array to realign, one gene per row
    :param gene_type: 'connection' when a key spans several columns that must all match;
                      any other value treats keys as scalars
    :return: float array with len(seq1) rows and the columns of ar2
    align means to intersect part of ar2 will be at the same position as ar1,
    non-intersect part of ar2 will be set to Nan
    """
    keys1, keys2 = seq1[:, np.newaxis], seq2[np.newaxis, :]
    # pairwise key comparison: mask[i, j] is True when seq1[i] matches seq2[j]
    mask = (keys1 == keys2) & (~np.isnan(keys1))

    if gene_type == 'connection':
        # a connection key matches only if every key column matches
        mask = np.all(mask, axis=2)

    intersect_mask = mask.any(axis=1)

    # positions must enumerate the rows of ar2 (the seq2 axis of mask), not the rows of seq1;
    # sizing this by seq1 breaks the dot product whenever the two genomes differ in length
    positions = np.arange(mask.shape[1])
    # with at most one match per row the dot product is the matched index (0 when no match,
    # which is masked out below)
    idx_fixed = np.dot(mask, positions)

    return np.where(intersect_mask[:, np.newaxis], ar2[idx_fixed], np.nan)
|
|
|
|
|
|
def crossover_gene(g1: NDArray, g2: NDArray) -> NDArray:
    """
    crossover two genes (or two aligned arrays of genes)
    every element of the result is taken from g1 or from g2 with equal probability, using an
    independent random draw per element; a single shared draw would return one parent
    unchanged instead of mixing them
    :param g1: gene values of the first parent
    :param g2: gene values of the second parent, aligned element by element with g1
    :return: array of the broadcast shape of g1 and g2 holding the mixed values
    only gene with the same key will be crossover, thus don't need to consider change key
    """
    # one uniform sample in [0, 1) per element of g1 (a plain float when g1 is a scalar)
    r = np.random.rand(*np.shape(g1))
    return np.where(r > 0.5, g1, g2)
|