odify genome for the official release
This commit is contained in:
111
tensorneat/common/tools.py
Normal file
111
tensorneat/common/tools.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INF = np.iinfo(jnp.int32).max # infinite int
|
||||
|
||||
# TODO: strange implementation
|
||||
def attach_with_inf(arr, idx):
|
||||
expand_size = arr.ndim - idx.ndim
|
||||
expand_idx = jnp.expand_dims(
|
||||
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
|
||||
)
|
||||
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_first(mask, default=I_INF) -> Array:
|
||||
"""
|
||||
fetch the first True index
|
||||
:param mask: array of bool
|
||||
:param default: the default value if no element satisfying the condition
|
||||
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
|
||||
"""
|
||||
idx = jnp.argmax(mask)
|
||||
return jnp.where(mask[idx], idx, default)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_random(randkey, mask, default=I_INF) -> Array:
|
||||
"""
|
||||
similar to fetch_first, but fetch a random True index
|
||||
"""
|
||||
true_cnt = jnp.sum(mask)
|
||||
cumsum = jnp.cumsum(mask)
|
||||
target = jax.random.randint(randkey, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["reverse"])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from small to large. default large to small
|
||||
"""
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_float(
|
||||
randkey, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate
|
||||
):
|
||||
"""
|
||||
mutate a float value
|
||||
uniformly pick r from [0, 1]
|
||||
r in [0, mutate_rate) -> add noise
|
||||
r in [mutate_rate, mutate_rate + replace_rate) -> create a new value to replace the original value
|
||||
otherwise -> keep the original value
|
||||
"""
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
noise = jax.random.normal(k1, ()) * mutate_power
|
||||
replace = jax.random.normal(k2, ()) * init_std + init_mean
|
||||
r = jax.random.uniform(k3, ())
|
||||
|
||||
val = jnp.where(
|
||||
r < mutate_rate,
|
||||
val + noise,
|
||||
jnp.where((mutate_rate < r) & (r < mutate_rate + replace_rate), replace, val),
|
||||
)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_int(randkey, val, options, replace_rate):
|
||||
"""
|
||||
mutate an int value
|
||||
uniformly pick r from [0, 1]
|
||||
r in [0, replace_rate) -> create a new value to replace the original value
|
||||
otherwise -> keep the original value
|
||||
"""
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
r = jax.random.uniform(k1, ())
|
||||
|
||||
val = jnp.where(r < replace_rate, jax.random.choice(k2, options), val)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def argmin_with_mask(arr, mask):
|
||||
"""
|
||||
find the index of the minimum element in the array, but only consider the element with True mask
|
||||
"""
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
|
||||
def hash_array(arr: Array):
|
||||
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
|
||||
|
||||
def update(i, hash_val):
|
||||
return hash_val ^ (
|
||||
arr[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2)
|
||||
)
|
||||
|
||||
return jax.lax.fori_loop(0, arr.size, update, jnp.uint32(0))
|
||||
Reference in New Issue
Block a user