complete normal neat algorithm

This commit is contained in:
wls2002
2023-07-18 23:55:36 +08:00
parent 40cf0b6fbe
commit 0a2a9fd1be
26 changed files with 880 additions and 251 deletions

View File

@@ -9,12 +9,11 @@ from jax import jit, Array, numpy as jnp
from ..utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, connections: Array) -> Array:
def topological_sort(nodes: Array, conns: Array) -> Array:
"""
a jit-able version of topological_sort! that's crazy!
:param nodes: nodes array
:param connections: connections array
:param conns: connections array
:return: topological sorted sequence
Example:
@@ -25,12 +24,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
@@ -41,8 +34,8 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1 # forward function. thus use enable
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
res = jnp.full(in_degree.shape, I_INT)
def cond_fun(carry):
@@ -59,7 +52,7 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
in_degree_ = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = connections_enable[i, :]
children = conns[i, :]
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
return res_, idx_ + 1, in_degree_
@@ -67,7 +60,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
return res
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.