optimize import

This commit is contained in:
wls2002
2023-06-29 09:41:49 +08:00
parent d28cef1a87
commit 01b7731231
14 changed files with 29 additions and 58 deletions

View File

@@ -2,8 +2,7 @@ from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
@@ -60,6 +59,7 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
@@ -68,4 +68,4 @@ def rank_elements(array, reverse=False):
"""
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))
return jnp.argsort(jnp.argsort(array))