modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.

This commit is contained in:
wls2002
2024-06-01 20:42:42 +08:00
parent 4ad9f0a85a
commit e65200a94e
14 changed files with 281 additions and 204 deletions

View File

@@ -1,6 +1,12 @@
import jax, jax.numpy as jnp
from .base import BaseCrossover
from utils.tools import (
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
)
class DefaultCrossover(BaseCrossover):
@@ -20,21 +26,33 @@ class DefaultCrossover(BaseCrossover):
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(
jnp.isnan(nodes1) | jnp.isnan(nodes2),
nodes1,
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys1, nodes1, nodes2),
node_attrs1 = jax.vmap(extract_node_attrs)(nodes1)
node_attrs2 = jax.vmap(extract_node_attrs)(nodes2)
new_node_attrs = jnp.where(
jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan
node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner)
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
state, randkeys1, node_attrs1, node_attrs2
), # homologous or both nan
)
new_nodes = jax.vmap(set_node_attrs)(nodes1, new_node_attrs)
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(
jnp.isnan(conns1) | jnp.isnan(conns2),
conns1,
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys2, conns1, conns2),
conns_attrs1 = jax.vmap(extract_conn_attrs)(conns1)
conns_attrs2 = jax.vmap(extract_conn_attrs)(conns2)
new_conn_attrs = jnp.where(
jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner)
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
state, randkeys2, conns_attrs1, conns_attrs2
), # homologous or both nan
)
new_conns = jax.vmap(set_conn_attrs)(conns1, new_conn_attrs)
return new_nodes, new_conns