Files
tensorneat-mend/algorithms/neat/genome/numpy/crossover.py
2023-05-06 21:04:28 +08:00

91 lines
2.8 KiB
Python

from typing import Tuple
import numpy as np
from numpy.typing import NDArray
from .utils import flatten_connections, unflatten_connections
def batch_crossover(batch_nodes1: NDArray, batch_connections1: NDArray, batch_nodes2: NDArray,
                    batch_connections2: NDArray) -> Tuple[NDArray, NDArray]:
    """
    Apply ``crossover`` to every genome pair of a batch.

    The four inputs are iterated in lockstep along their first axis; entry
    ``i`` of each forms one (genome1, genome2) pair handed to ``crossover``.

    :param batch_nodes1: nodes of the first parents, one genome per leading index
    :param batch_connections1: connections of the first parents
    :param batch_nodes2: nodes of the second parents
    :param batch_connections2: connections of the second parents
    :return: ``(nodes, connections)`` of the children, each stacked along a new axis 0
    """
    parent_groups = zip(batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
    # One (new_nodes, new_cons) pair per parent pair, in batch order.
    children = [crossover(*group) for group in parent_groups]
    child_nodes = [child[0] for child in children]
    child_cons = [child[1] for child in children]
    return np.stack(child_nodes, axis=0), np.stack(child_cons, axis=0)
def crossover(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) \
        -> Tuple[NDArray, NDArray]:
    """
    Build a child genome from genome1 and genome2.

    genome1 is expected to be the fitter parent (the "winner"): wherever a
    value is NaN in either parent after alignment, the child takes genome1's
    value. Everywhere else the value comes from ``crossover_gene``.

    :param nodes1: node array of genome1; column 0 is used as the node key
    :param connections1: connections of genome1, in the layout accepted by
        ``flatten_connections``
    :param nodes2: node array of genome2; column 0 is used as the node key
    :param connections2: connections of genome2
    :return: ``(new_nodes, new_connections)``; the connections are passed back
        through ``unflatten_connections``
    """
    # ---- node genes ----
    node_keys1 = nodes1[:, 0]
    node_keys2 = nodes2[:, 0]
    # Put genome2's nodes at the rows where genome1 holds the same key.
    aligned_nodes2 = align_array(node_keys1, node_keys2, nodes2, 'node')
    nodes_shared = ~np.isnan(nodes1) & ~np.isnan(aligned_nodes2)
    mixed_nodes = crossover_gene(nodes1, aligned_nodes2)
    new_nodes = np.where(nodes_shared, mixed_nodes, nodes1)

    # ---- connection genes ----
    flat1 = flatten_connections(node_keys1, connections1)
    flat2 = flatten_connections(node_keys2, connections2)
    # The first two columns of a flattened connection row act as its key.
    aligned_flat2 = align_array(flat1[:, :2], flat2[:, :2], flat2, 'connection')
    cons_shared = ~np.isnan(flat1) & ~np.isnan(aligned_flat2)
    mixed_cons = crossover_gene(flat1, aligned_flat2)
    merged_cons = np.where(cons_shared, mixed_cons, flat1)

    return new_nodes, unflatten_connections(len(node_keys1), merged_cons)
def align_array(seq1: NDArray, seq2: NDArray, ar2: NDArray, gene_type: str) -> NDArray:
    """
    Reorder the rows of ``ar2`` so that they line up with the keys in ``seq1``.

    Row ``i`` of the result is the row of ``ar2`` whose key (in ``seq2``)
    equals ``seq1[i]``. Rows of ``seq1`` that have no matching key in
    ``seq2`` (including NaN keys, which never match) are filled with NaN.
    If a key occurs more than once in ``seq2`` the first occurrence is used.

    :param seq1: keys to align to; 1-D for nodes, ``(n, k)`` for connections
    :param seq2: keys of ``ar2``, one per row of ``ar2``; may differ in
        length from ``seq1``
    :param ar2: 2-D array whose rows are to be rearranged
    :param gene_type: ``'connection'`` when keys are multi-column rows that
        must match in every column; any other value treats keys as scalars
    :return: float array with ``len(seq1)`` rows and ``ar2``'s column count
    """
    keys1 = seq1[:, np.newaxis]
    keys2 = seq2[np.newaxis, :]
    # mask[i, j] is True when key i of seq1 equals key j of seq2.
    mask = (keys1 == keys2) & (~np.isnan(keys1))
    if gene_type == 'connection':
        # A connection key matches only if all of its components match.
        mask = np.all(mask, axis=2)

    if mask.shape[1] == 0:
        # Nothing to align against: every row is a non-match.
        return np.full((mask.shape[0],) + ar2.shape[1:], np.nan)

    intersect_mask = mask.any(axis=1)
    # Index into seq2/ar2 of the first match per row (0 when there is none;
    # those rows are overwritten with NaN below). argmax, unlike a dot
    # product with an index vector, works when len(seq1) != len(seq2) and
    # stays in range when seq2 contains duplicate keys.
    idx_fixed = mask.argmax(axis=1)
    return np.where(intersect_mask[:, np.newaxis], ar2[idx_fixed], np.nan)
def crossover_gene(g1: NDArray, g2: NDArray) -> NDArray:
    """
    Mix two aligned gene arrays by picking each element from one parent.

    Every element is chosen independently, with equal probability, from
    ``g1`` or ``g2``. Only genes sharing the same key are meant to be
    crossed, so the key entries are equal in both inputs and the choice
    cannot change a key.

    :param g1: gene values of the first parent
    :param g2: gene values of the second parent, aligned with ``g1``
    :return: array in which each entry equals the corresponding entry of
        ``g1`` or of ``g2``
    """
    # Draw one random number per element; a single scalar draw would copy
    # one parent wholesale and never actually mix the two.
    r = np.random.rand(*np.shape(g1))
    return np.where(r > 0.5, g1, g2)