debug-branch
This commit is contained in:
@@ -48,7 +48,7 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
||||
|
||||
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
|
||||
if gene_type == 'node':
|
||||
node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||
else: # connection
|
||||
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||
|
||||
@@ -64,7 +64,17 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
|
||||
|
||||
|
||||
@partial(vmap, in_axes=(0, 0))
|
||||
@vmap
|
||||
def batch_homologous_node_distance(b_n1, b_n2):
|
||||
return homologous_node_distance(b_n1, b_n2)
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_homologous_connection_distance(b_c1, b_c2):
|
||||
return homologous_connection_distance(b_c1, b_c2)
|
||||
|
||||
|
||||
@jit
|
||||
def homologous_node_distance(n1, n2):
|
||||
d = 0
|
||||
d += jnp.abs(n1[1] - n2[1]) # bias
|
||||
@@ -74,7 +84,7 @@ def homologous_node_distance(n1, n2):
|
||||
return d
|
||||
|
||||
|
||||
@partial(vmap, in_axes=(0, 0))
|
||||
@jit
|
||||
def homologous_connection_distance(c1, c2):
|
||||
d = 0
|
||||
d += jnp.abs(c1[2] - c2[2]) # weight
|
||||
|
||||
Reference in New Issue
Block a user