modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.

This commit is contained in:
wls2002
2024-06-01 20:42:42 +08:00
parent 4ad9f0a85a
commit e65200a94e
14 changed files with 281 additions and 204 deletions

View File

@@ -1,5 +1,12 @@
import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first
from utils import (
State,
rank_elements,
argmin_with_mask,
fetch_first,
extract_conn_attrs,
extract_node_attrs,
)
from ..genome import BaseGenome
from .base import BaseSpecies
@@ -557,8 +564,10 @@ class DefaultSpecies(BaseSpecies):
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
fr_attrs = jax.vmap(extract_node_attrs)(fr)
sr_attrs = jax.vmap(extract_node_attrs)(sr)
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
state, fr, sr
state, fr_attrs, sr_attrs
) # homologous node distance
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
@@ -593,8 +602,11 @@ class DefaultSpecies(BaseSpecies):
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
fr_attrs = jax.vmap(extract_conn_attrs)(fr)
sr_attrs = jax.vmap(extract_conn_attrs)(sr)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
state, fr, sr
state, fr_attrs, sr_attrs
) # homologous connection distance
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)