bug down! Here it can solve xor successfully!
This commit is contained in:
@@ -74,26 +74,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
|
||||
return res
|
||||
|
||||
|
||||
# @jit
|
||||
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
|
||||
connections_enable = connections[1, :, :] == 1
|
||||
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
|
||||
res = jnp.full(in_degree.shape, I_INT)
|
||||
idx = 0
|
||||
|
||||
for _ in range(in_degree.shape[0]):
|
||||
i = fetch_first(in_degree == 0.)
|
||||
if i == I_INT:
|
||||
break
|
||||
res = res.at[idx].set(i)
|
||||
idx += 1
|
||||
in_degree = in_degree.at[i].set(-1)
|
||||
children = connections_enable[i, :]
|
||||
in_degree = jnp.where(children, in_degree - 1, in_degree)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
"""
|
||||
@@ -102,7 +82,7 @@ def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
:param pop_connections:
|
||||
:return:
|
||||
"""
|
||||
return topological_sort(nodes, connections)
|
||||
return topological_sort(pop_nodes, pop_connections)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -148,7 +128,6 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
check_cycles(nodes, connections, 0, 3) -> False
|
||||
check_cycles(nodes, connections, 1, 0) -> False
|
||||
"""
|
||||
# connections_enable = connections[0, :, :] == 1
|
||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||
|
||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||
@@ -191,7 +170,6 @@ if __name__ == '__main__':
|
||||
]
|
||||
)
|
||||
|
||||
print(topological_sort_debug(nodes, connections))
|
||||
print(topological_sort(nodes, connections))
|
||||
|
||||
print(check_cycles(nodes, connections, 3, 2))
|
||||
|
||||
Reference in New Issue
Block a user