update recurrent genome

This commit is contained in:
root
2024-07-10 16:27:49 +08:00
parent 1d606eb1c3
commit 649d4b0552
8 changed files with 490 additions and 46 deletions

16
examples/tmp2.py Normal file
View File

@@ -0,0 +1,16 @@
import jax, jax.numpy as jnp
arr = jnp.ones((10, 10))
a = jnp.array([
[1, 2, 3],
[4, 5, 6]
])
def attach_with_inf(arr, idx):
target_dim = arr.ndim + idx.ndim - 1
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
return jnp.where(expand_idx == 1, jnp.nan, arr[idx])
b = attach_with_inf(arr, a)
print(b)