fix bug in crossover: the child from two normal networks should always be normal.

This commit is contained in:
wls2002
2024-05-22 10:27:32 +08:00
parent d1559317d1
commit 6a37563696
11 changed files with 46 additions and 43 deletions

View File

@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = self.align_array(keys1, keys2, nodes2, False)
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
return new_nodes, new_conns
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
return refactor_ar2
def crossover_gene(self, rand_key, g1, g2):
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
def crossover_gene(self, rand_key, g1, g2, is_conn):
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
1,
0
)
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene