68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
"""
Some graph algorithms implemented in JAX.

Only used in feed-forward networks.
"""
|
|
|
|
import jax
|
|
from jax import jit, Array, numpy as jnp
|
|
|
|
from algorithm.utils import fetch_first, I_INT
|
|
|
|
|
|
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
    """
    A jit-able topological sort (Kahn's algorithm) over a dense adjacency matrix.

    Args:
        nodes: 2-D node array. A node whose column-0 value is NaN is excluded:
            its in-degree is set to NaN, which never equals 0, so it is never
            emitted.
        conns: square adjacency matrix where a truthy ``conns[i, j]`` is read
            as an edge ``i -> j``. Column sums give in-degrees, and row ``i``
            holds the children of node ``i``.

    Returns:
        An array with one slot per node. It holds node indices in topological
        order, followed by ``I_INT`` padding. Nodes on a cycle, or reachable
        only through one, never reach in-degree 0 and are left out.
    """
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
    res = jnp.full(in_degree.shape, I_INT)

    # The carry also holds the next ready node ``i``. This way fetch_first
    # runs once per iteration instead of once in cond and again in body.
    # fetch_first presumably yields I_INT when no entry matches; that
    # sentinel is what ends the loop.
    def cond_fun(carry):
        _, _, _, i = carry
        return i != I_INT

    def body_func(carry):
        res_, idx_, in_degree_, i = carry

        # Append i to the result, then mark it as emitted. -1 never equals 0,
        # so i cannot be picked again.
        res_ = res_.at[idx_].set(i)
        in_degree_ = in_degree_.at[i].set(-1)

        # Decrease the in-degree of all children of i.
        children = conns[i, :]
        in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)

        next_i = fetch_first(in_degree_ == 0.)
        return res_, idx_ + 1, in_degree_, next_i

    first = fetch_first(in_degree == 0.)
    res, _, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree, first))
    return res
|
|
|
|
|
|
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
    """
    Tell whether adding the connection ``from_idx -> to_idx`` would close a cycle.

    The candidate edge is written into a copy of ``conns``. Starting from
    ``to_idx``, the reachable set then grows by one hop per loop iteration.
    Growth stops when nothing new is reached or when ``from_idx`` is reached.

    Args:
        nodes: node array; only its first dimension (the node count) is used.
        conns: square adjacency matrix; a truthy ``conns[i, j]`` is treated as
            an edge ``i -> j`` (``reachable @ conns`` walks from rows to columns).
        from_idx: source index of the candidate connection.
        to_idx: target index of the candidate connection.

    Returns:
        A scalar boolean array: True when ``from_idx`` is reachable from
        ``to_idx`` (including ``from_idx == to_idx``), i.e. a cycle would form.
    """
    graph = conns.at[from_idx, to_idx].set(True)

    nothing = jnp.full(nodes.shape[0], False)
    seed = nothing.at[to_idx].set(True)

    def keep_expanding(state):
        previous, current = state
        still_growing = jnp.any(previous != current)
        # Stop early once the source of the new edge has been reached.
        return still_growing & jnp.logical_not(current[from_idx])

    def expand_one_hop(state):
        current = state[1]
        one_hop = jnp.dot(current, graph)
        return current, jnp.logical_or(current, one_hop)

    final_state = jax.lax.while_loop(keep_expanding, expand_one_hop, (nothing, seed))
    reachable = final_state[1]
    return reachable[from_idx]
|