bug down! Here it can solve xor successfully!

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -74,26 +74,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
return res
# @jit
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
connections_enable = connections[1, :, :] == 1
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
for _ in range(in_degree.shape[0]):
i = fetch_first(in_degree == 0.)
if i == I_INT:
break
res = res.at[idx].set(i)
idx += 1
in_degree = in_degree.at[i].set(-1)
children = connections_enable[i, :]
in_degree = jnp.where(children, in_degree - 1, in_degree)
return res
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
@@ -102,7 +82,7 @@ def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
:param pop_connections:
:return:
"""
return topological_sort(nodes, connections)
return topological_sort(pop_nodes, pop_connections)
@jit
@@ -148,7 +128,6 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
# connections_enable = connections[0, :, :] == 1
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
@@ -191,7 +170,6 @@ if __name__ == '__main__':
]
)
print(topological_sort_debug(nodes, connections))
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))