update recurrent genome
This commit is contained in:
@@ -6,12 +6,11 @@ from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INF = np.iinfo(jnp.int32).max # infinite int
|
||||
|
||||
# TODO: strange implementation
|
||||
|
||||
def attach_with_inf(arr, idx):
|
||||
expand_size = arr.ndim - idx.ndim
|
||||
expand_idx = jnp.expand_dims(
|
||||
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
|
||||
)
|
||||
target_dim = arr.ndim + idx.ndim - 1
|
||||
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
|
||||
|
||||
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user