Files
tensorneat-mend/tensorneat/utils/tools.py
2024-03-26 21:58:27 +08:00

106 lines
2.9 KiB
Python

from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array, jit, vmap
# Largest value representable as int32. fetch_first / fetch_random use it as
# their default "nothing found" return value.
I_INT = np.iinfo(jnp.int32).max # infinite int
def unflatten_conns(nodes, conns):
    """
    Scatter the flat connection table into dense per-attribute matrices.

    Transforms the (C, CL) connections into an array of shape (CL-2, N, N);
    the 2 removed columns are the input key and the output key.

    :param nodes: node table; column 0 is read as the node key, N = nodes.shape[0]
    :param conns: connection table; columns 0 and 1 are read as the input and
        output node keys, every remaining column (enable flag included) is
        treated as an attribute
    :return: (CL-2, N, N) array where entry [a, i, o] is attribute ``a`` of the
        connection from node index ``i`` to node index ``o``; entries that were
        never written stay NaN
    """
    node_cnt = nodes.shape[0]
    attr_cnt = conns.shape[1] - 2
    keys = nodes[:, 0]

    # Translate every connection endpoint key into a row index of `nodes`.
    locate = vmap(key_to_indices, in_axes=(0, None))
    src_idx = locate(conns[:, 0], keys)
    dst_idx = locate(conns[:, 1], keys)

    # One row per attribute, one column per connection.
    attrs = conns[:, 2:].T

    dense = jnp.full((attr_cnt, node_cnt, node_cnt), jnp.nan)

    # Endpoints whose key matches no node come back from key_to_indices as an
    # out-of-range index. Original author's note: JAX clips out-of-range
    # indices when reading, but silently skips them when setting values, so
    # those connections simply write nothing here.
    return dense.at[:, src_idx, dst_idx].set(attrs)
def key_to_indices(key, keys):
    """
    Locate ``key`` inside ``keys``.

    :param key: the value to look for
    :param keys: array compared element-wise against ``key``
    :return: whatever fetch_first returns for the equality mask, i.e. the index
        of the first match, or fetch_first's default when nothing matches
    """
    matches = key == keys
    return fetch_first(matches)
@jit
def fetch_first(mask, default=I_INT) -> Array:
    """
    Find the position of the first True entry in ``mask``.

    :param mask: array of bool
    :param default: value returned when no entry of ``mask`` is True
    :return: index of the first True entry, or ``default`` if there is none
    """
    # argmax over booleans lands on the first True; when every entry is False
    # it lands on index 0, which the lookup below detects.
    candidate = jnp.argmax(mask)
    found = mask[candidate]
    return jnp.where(found, candidate, default)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
    """
    Like fetch_first, but the returned index is a randomly chosen True entry.

    :param rand_key: JAX random key used for the draw
    :param mask: array of bool
    :param default: value returned when no entry of ``mask`` is True
    :return: index of the selected True entry, or ``default`` if there is none
    """
    running = jnp.cumsum(mask)
    total = jnp.sum(mask)

    # Draw which True entry to take, counted from 1 up to the number of Trues.
    nth = jax.random.randint(rand_key, shape=(), minval=1, maxval=total + 1)

    # The first position whose running count reaches `nth` is the nth True
    # entry. With no True entries at all, force an all-False mask so that
    # fetch_first falls back to `default`.
    reached = jnp.where(total == 0, False, running >= nth)
    return fetch_first(reached, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
    """
    Compute the rank of every element of ``array``.

    :param array: values to rank
    :param reverse: static flag. False (default): rank 0 goes to the largest
        element (large to small). True: rank 0 goes to the smallest element
        (small to large).
    :return: array of the same length holding each element's rank
    """
    # Negating the values turns the ascending sort into a descending ranking.
    sort_keys = array if reverse else -array
    order = jnp.argsort(sort_keys)
    # Applying argsort to a permutation inverts it: position -> rank.
    return jnp.argsort(order)
@jit
def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
    """
    Randomly perturb, replace, or keep a float value.

    A single uniform draw ``r`` in [0, 1) selects the outcome:

    * ``r < mutate_rate``: return ``val`` plus normal noise scaled by
      ``mutate_power``
    * ``mutate_rate <= r < mutate_rate + replace_rate``: return a fresh sample
      from a normal distribution with mean ``init_mean`` and std ``init_std``
    * otherwise: return ``val`` unchanged

    :param key: JAX random key; split into three independent sub-keys
    :param val: the current value
    :param init_mean: mean of the replacement distribution
    :param init_std: standard deviation of the replacement distribution
    :param mutate_power: scale of the additive noise
    :param mutate_rate: width of the "add noise" band of ``r``
    :param replace_rate: width of the "replace" band of ``r``
    :return: the new value
    """
    noise_key, replace_key, select_key = jax.random.split(key, num=3)

    noise = jax.random.normal(noise_key, ()) * mutate_power
    replace = jax.random.normal(replace_key, ()) * init_std + init_mean
    r = jax.random.uniform(select_key, ())

    # The outer branch already consumes r < mutate_rate, so the inner test only
    # needs the upper bound. Requiring `mutate_rate < r` as well would leave a
    # gap at r == mutate_rate (e.g. r == 0.0 with mutate_rate == 0) where the
    # value is neither mutated nor replaced.
    return jnp.where(
        r < mutate_rate,
        val + noise,
        jnp.where(r < mutate_rate + replace_rate, replace, val),
    )
@jit
def mutate_int(key, val, options, replace_rate):
    """
    Keep an integer value or swap it for a random pick from ``options``.

    :param key: JAX random key; split into two independent sub-keys
    :param val: the current value
    :param options: array of candidate values to choose from
    :param replace_rate: the value is replaced when a uniform draw in [0, 1)
        is below this threshold
    :return: the randomly chosen option, or ``val`` when no replacement happens
    """
    gate_key, pick_key = jax.random.split(key, num=2)

    candidate = jax.random.choice(pick_key, options)
    should_replace = jax.random.uniform(gate_key, ()) < replace_rate

    return jnp.where(should_replace, candidate, val)
def argmin_with_mask(arr, mask):
    """
    Index of the smallest element of ``arr`` among positions where ``mask`` is True.

    :param arr: values to search
    :param mask: array of bool selecting the positions that may be returned
    :return: index of the minimum over the selected positions; if nothing is
        selected every candidate is +inf and argmin yields index 0
    """
    # Push excluded positions to +inf so they can never win the argmin.
    excluded = jnp.logical_not(mask)
    candidates = jnp.where(excluded, jnp.inf, arr)
    return jnp.argmin(candidates)