use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -6,7 +6,7 @@ Only used in feed-forward networks.
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
from .tools import fetch_first, I_INT
|
||||
from .tools import fetch_first, I_INF
|
||||
|
||||
|
||||
@jit
|
||||
@@ -17,16 +17,16 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
"""
|
||||
|
||||
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
|
||||
res = jnp.full(in_degree.shape, I_INT)
|
||||
res = jnp.full(in_degree.shape, I_INF)
|
||||
|
||||
def cond_fun(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
return i != I_INT
|
||||
i = fetch_first(in_degree_ == 0.0)
|
||||
return i != I_INF
|
||||
|
||||
def body_func(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
i = fetch_first(in_degree_ == 0.0)
|
||||
|
||||
# add to res and flag it is already in it
|
||||
res_ = res_.at[idx_].set(i)
|
||||
@@ -65,4 +65,4 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||
return visited_, new_visited_
|
||||
|
||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||
return visited[from_idx]
|
||||
return visited[from_idx]
|
||||
|
||||
Reference in New Issue
Block a user