debug-branch

This commit is contained in:
wls2002
2023-05-06 21:04:28 +08:00
parent 14fed83193
commit a85e6eba78
20 changed files with 1719 additions and 233 deletions

View File

@@ -48,7 +48,7 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
@@ -64,7 +64,17 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
@partial(vmap, in_axes=(0, 0))
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
return homologous_node_distance(b_n1, b_n2)
@vmap
def batch_homologous_connection_distance(b_c1, b_c2):
return homologous_connection_distance(b_c1, b_c2)
@jit
def homologous_node_distance(n1, n2):
d = 0
d += jnp.abs(n1[1] - n2[1]) # bias
@@ -74,7 +84,7 @@ def homologous_node_distance(n1, n2):
return d
@partial(vmap, in_axes=(0, 0))
@jit
def homologous_connection_distance(c1, c2):
d = 0
d += jnp.abs(c1[2] - c2[2]) # weight