use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -6,7 +6,7 @@ Only used in feed-forward networks.
import jax
from jax import jit, Array, numpy as jnp
from .tools import fetch_first, I_INT
from .tools import fetch_first, I_INF
@jit
@@ -17,16 +17,16 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
"""
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
res = jnp.full(in_degree.shape, I_INT)
res = jnp.full(in_degree.shape, I_INF)
def cond_fun(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
return i != I_INT
i = fetch_first(in_degree_ == 0.0)
return i != I_INF
def body_func(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
i = fetch_first(in_degree_ == 0.0)
# add to res and flag it is already in it
res_ = res_.at[idx_].set(i)
@@ -65,4 +65,4 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]
return visited[from_idx]