modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from .base import BaseCrossover
|
||||
from utils.tools import (
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
)
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
@@ -20,21 +26,33 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(
|
||||
jnp.isnan(nodes1) | jnp.isnan(nodes2),
|
||||
nodes1,
|
||||
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys1, nodes1, nodes2),
|
||||
node_attrs1 = jax.vmap(extract_node_attrs)(nodes1)
|
||||
node_attrs2 = jax.vmap(extract_node_attrs)(nodes2)
|
||||
|
||||
new_node_attrs = jnp.where(
|
||||
jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan
|
||||
node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner)
|
||||
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys1, node_attrs1, node_attrs2
|
||||
), # homologous or both nan
|
||||
)
|
||||
new_nodes = jax.vmap(set_node_attrs)(nodes1, new_node_attrs)
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
new_conns = jnp.where(
|
||||
jnp.isnan(conns1) | jnp.isnan(conns2),
|
||||
conns1,
|
||||
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys2, conns1, conns2),
|
||||
conns_attrs1 = jax.vmap(extract_conn_attrs)(conns1)
|
||||
conns_attrs2 = jax.vmap(extract_conn_attrs)(conns2)
|
||||
|
||||
new_conn_attrs = jnp.where(
|
||||
jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
|
||||
conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner)
|
||||
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys2, conns_attrs1, conns_attrs2
|
||||
), # homologous or both nan
|
||||
)
|
||||
new_conns = jax.vmap(set_conn_attrs)(conns1, new_conn_attrs)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
|
||||
Reference in New Issue
Block a user