Files
tensorneat-mend/algorithms/neat/genome/crossover.py
2023-05-06 18:33:30 +08:00

152 lines
4.6 KiB
Python

from functools import partial
from typing import Tuple
import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import flatten_connections, unflatten_connections
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
@vmap
def batch_crossover(
        randkeys: Array,
        batch_nodes1: Array,
        batch_connections1: Array,
        batch_nodes2: Array,
        batch_connections2: Array,
) -> Tuple[Array, Array]:
    """
    Vectorised form of ``crossover``: every argument carries an extra leading
    batch axis and ``vmap`` applies ``crossover`` to each slice along it.

    :param randkeys: one random key per genome pair
    :param batch_nodes1: node arrays of the first parents
    :param batch_connections1: connection arrays of the first parents
    :param batch_nodes2: node arrays of the second parents
    :param batch_connections2: connection arrays of the second parents
    :return: the batched child nodes and the batched child connections
    """
    child_nodes, child_connections = crossover(
        randkeys,
        batch_nodes1,
        batch_connections1,
        batch_nodes2,
        batch_connections2,
    )
    return child_nodes, child_connections
@jit
def crossover(
        randkey: Array,
        nodes1: Array,
        connections1: Array,
        nodes2: Array,
        connections2: Array,
) -> Tuple[Array, Array]:
    """
    Build a child genome out of two parent genomes.

    Genome 1 is expected to be the fitter parent (the winner): wherever a gene
    is missing (NaN) in either parent, the value of genome 1 is kept as is.

    :param randkey: random key; split once for nodes and once for connections
    :param nodes1: node array of the winner, column 0 holds the node key
    :param connections1: connection array of the winner
    :param nodes2: node array of the other parent, column 0 holds the node key
    :param connections2: connection array of the other parent
    :return: the child's node array and the child's connection array
    """
    node_rng, con_rng = jax.random.split(randkey)

    # ---- nodes -----------------------------------------------------------
    node_keys1 = nodes1[:, 0]
    node_keys2 = nodes2[:, 0]
    # Move the rows of parent 2 onto the row positions used by parent 1.
    aligned_nodes2 = align_array(node_keys1, node_keys2, nodes2, 'node')
    node_unpaired = jnp.isnan(nodes1) | jnp.isnan(aligned_nodes2)
    node_mix = crossover_gene(node_rng, nodes1, aligned_nodes2)
    new_nodes = jnp.where(node_unpaired, nodes1, node_mix)

    # ---- connections -----------------------------------------------------
    # The flattened form is treated as one row per connection whose first two
    # columns identify it (presumably the in/out node keys; see
    # flatten_connections to confirm).
    flat1 = flatten_connections(node_keys1, connections1)
    flat2 = flatten_connections(node_keys2, connections2)
    aligned_flat2 = align_array(flat1[:, :2], flat2[:, :2], flat2, 'connection')
    con_unpaired = jnp.isnan(flat1) | jnp.isnan(aligned_flat2)
    con_mix = crossover_gene(con_rng, flat1, aligned_flat2)
    new_flat = jnp.where(con_unpaired, flat1, con_mix)

    new_cons = unflatten_connections(len(node_keys1), new_flat)
    return new_nodes, new_cons
@partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
    """
    Reorder the rows of ar2 so that they line up with the keys in seq1.

    :param seq1: keys of the reference array, one key per row; a scalar per
        row, or a multi-column key per row when gene_type is 'connection'
    :param seq2: keys of ar2, laid out like seq1; it may hold a different
        number of keys than seq1
    :param ar2: array whose i-th row is identified by seq2[i]
    :param gene_type: 'connection' compares keys column-wise and requires all
        columns to match; any other value compares scalar keys
    :return: an array with one row per key in seq1. Row i is the row of ar2
        whose key equals seq1[i]; rows whose key is NaN or does not occur in
        seq2 are filled with NaN.

    Keys are assumed to be unique inside seq2: the row lookup sums the indices
    of all matches, so duplicated keys would yield a wrong index.
    """
    keys1, keys2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
    # Pairwise comparison: match[i, j] says whether key i of seq1 equals key j
    # of seq2. NaN never equals anything; the isnan guard states it explicitly.
    match = (keys1 == keys2) & (~jnp.isnan(keys1))
    if gene_type == 'connection':
        # A multi-column key matches only if every column matches.
        match = jnp.all(match, axis=2)
    has_partner = match.any(axis=1)

    # The looked-up indices address rows of ar2, so the index range has to
    # span the seq2 axis of the mask (not the seq1 axis), otherwise inputs of
    # different lengths cannot be aligned.
    positions = jnp.arange(match.shape[1])
    # With at most one match per row, the dot product selects its position.
    partner_idx = jnp.dot(match, positions)

    # Rows without a partner picked index 0 above; blank them out with NaN.
    return jnp.where(has_partner[:, jnp.newaxis], ar2[partner_idx], jnp.nan)
@jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
    """
    Mix two gene arrays element by element.

    :param rand_key: random key used to draw one uniform number per element
    :param g1: first gene array; its shape determines the shape of the draw
    :param g2: second gene array, aligned with g1
    :return: array that holds the element of g1 where the draw is above 0.5
        and the element of g2 everywhere else

    The caller only mixes genes that share the same key, so the key columns
    are identical in both inputs and need no special treatment here.
    """
    draws = jax.random.uniform(rand_key, shape=g1.shape)
    pick_second = draws <= 0.5
    return jnp.where(pick_second, g2, g1)
if __name__ == '__main__':
    # Small manual demo: cross two hand-written genomes and print the child.
    import numpy as np

    nan = np.nan
    empty_row = [nan] * 5                      # one unused (padding) row
    empty_block = [empty_row for _ in range(5)]  # kept for readability below

    demo_key = jax.random.PRNGKey(40)

    # Column 0 is the node key; NaN rows are unused slots.
    parent1_nodes = np.array([
        [4, 1, 1, 1, 1],
        [6, 2, 2, 2, 2],
        [1, 3, 3, 3, 3],
        [5, 4, 4, 4, 4],
        empty_row,
    ])
    parent2_nodes = np.array([
        [4, 1.5, 1.5, 1.5, 1.5],
        [7, 3.5, 3.5, 3.5, 3.5],
        [5, 4.5, 4.5, 4.5, 4.5],
        empty_row,
        empty_row,
    ])

    # Two stacked 5x5 planes per parent; NaN marks an absent entry.
    parent1_cons = np.array([
        [
            [1, 2, 3, 4., nan],
            [5, nan, 7, 8, nan],
            [9, 10, 11, 12, nan],
            [13, 14, 15, 16, nan],
            empty_row,
        ],
        [
            [0, 1, 0, 1, nan],
            [0, nan, 0, 1, nan],
            [0, 1, 0, 1, nan],
            [0, 1, 0, 1, nan],
            empty_row,
        ],
    ])
    parent2_cons = np.array([
        [
            [1.5, 2.5, 3.5, nan, nan],
            [3.5, 4.5, 5.5, nan, nan],
            [6.5, 7.5, 8.5, nan, nan],
            empty_row,
            empty_row,
        ],
        [
            [1, 0, 1, nan, nan],
            [1, 0, 1, nan, nan],
            [1, 0, 1, nan, nan],
            empty_row,
            empty_row,
        ],
    ])

    child_nodes, child_cons = crossover(
        demo_key, parent1_nodes, parent1_cons, parent2_nodes, parent2_cons
    )
    print(child_nodes, child_cons, sep='\n')