modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.
This commit is contained in:
@@ -1,5 +1,12 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State, rank_elements, argmin_with_mask, fetch_first
|
||||
from utils import (
|
||||
State,
|
||||
rank_elements,
|
||||
argmin_with_mask,
|
||||
fetch_first,
|
||||
extract_conn_attrs,
|
||||
extract_node_attrs,
|
||||
)
|
||||
from ..genome import BaseGenome
|
||||
from .base import BaseSpecies
|
||||
|
||||
@@ -557,8 +564,10 @@ class DefaultSpecies(BaseSpecies):
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
fr_attrs = jax.vmap(extract_node_attrs)(fr)
|
||||
sr_attrs = jax.vmap(extract_node_attrs)(sr)
|
||||
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr, sr
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous node distance
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
@@ -593,8 +602,11 @@ class DefaultSpecies(BaseSpecies):
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
fr_attrs = jax.vmap(extract_conn_attrs)(fr)
|
||||
sr_attrs = jax.vmap(extract_conn_attrs)(sr)
|
||||
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr, sr
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous connection distance
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
Reference in New Issue
Block a user