optimize import
This commit is contained in:
@@ -2,8 +2,7 @@ from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
from jax import jit, vmap
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = np.full((1, 5), jnp.nan)
|
||||
@@ -60,6 +59,7 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
@@ -68,4 +68,4 @@ def rank_elements(array, reverse=False):
|
||||
"""
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
|
||||
Reference in New Issue
Block a user