debug-branch
This commit is contained in:
@@ -95,11 +95,11 @@ def topological_sort_debug(nodes: Array, connections: Array) -> Array:
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_topological_sort(nodes: Array, connections: Array) -> Array:
|
||||
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
"""
|
||||
batch version of topological_sort
|
||||
:param nodes:
|
||||
:param connections:
|
||||
:param pop_nodes:
|
||||
:param pop_connections:
|
||||
:return:
|
||||
"""
|
||||
return topological_sort(nodes, connections)
|
||||
@@ -175,17 +175,17 @@ if __name__ == '__main__':
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[0, 0, 1, 0, jnp.nan],
|
||||
[0, 0, 1, 1, jnp.nan],
|
||||
[0, 0, 0, 1, jnp.nan],
|
||||
[0, 0, 0, 0, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
],
|
||||
[
|
||||
[0, 0, 1, 0, jnp.nan],
|
||||
[0, 0, 1, 1, jnp.nan],
|
||||
[0, 0, 0, 1, jnp.nan],
|
||||
[0, 0, 0, 0, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user