move crossover_gene from ga.crossover to gene.basegene.

This commit is contained in:
wls2002
2024-05-30 15:06:08 +08:00
parent 9f6154d128
commit 20320105e6
4 changed files with 34 additions and 10 deletions

View File

@@ -10,6 +10,8 @@ class DefaultCrossover(BaseCrossover):
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey1, randkey2 = jax.random.split(randkey, 2)
randkeys1 = jax.random.split(randkey1, genome.max_nodes)
randkeys2 = jax.random.split(randkey2, genome.max_conns)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
@@ -21,7 +23,7 @@ class DefaultCrossover(BaseCrossover):
new_nodes = jnp.where(
jnp.isnan(nodes1) | jnp.isnan(nodes2),
nodes1,
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False),
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys1, nodes1, nodes2),
)
# crossover connections
@@ -31,7 +33,7 @@ class DefaultCrossover(BaseCrossover):
new_conns = jnp.where(
jnp.isnan(conns1) | jnp.isnan(conns2),
conns1,
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys2, conns1, conns2),
)
return new_nodes, new_conns
@@ -64,11 +66,3 @@ class DefaultCrossover(BaseCrossover):
)
return refactor_ar2
def crossover_gene(self, randkey, g1, g2, is_conn):
r = jax.random.uniform(randkey, shape=g1.shape)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene