update recurrent genome

This commit is contained in:
root
2024-07-10 16:27:49 +08:00
parent 1d606eb1c3
commit 649d4b0552
8 changed files with 490 additions and 46 deletions

View File

@@ -6,12 +6,11 @@ from jax import numpy as jnp, Array, jit, vmap
I_INF = np.iinfo(jnp.int32).max # infinite int
# TODO: strange implementation
def attach_with_inf(arr, idx):
expand_size = arr.ndim - idx.ndim
expand_idx = jnp.expand_dims(
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
)
target_dim = arr.ndim + idx.ndim - 1
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])