adjust parameter for xor problem

This commit is contained in:
wls2002
2023-05-07 16:21:41 +08:00
parent a3b9bca866
commit 890c928b0f
6 changed files with 23 additions and 18 deletions

View File

@@ -26,7 +26,7 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
max_cnt = jnp.maximum(node_cnt1, node_cnt2) - 2
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
@@ -72,6 +72,7 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
return homologous_node_distance(b_n1, b_n2)