Files
tensorneat-mend/examples/a.py
2023-07-21 15:03:12 +08:00

12 lines
178 B
Python

import numpy as np
import jax.numpy as jnp
a = jnp.zeros((5, 5))
k1 = jnp.array([1, 2, 3])
k2 = jnp.array([2, 3, 4])
v = jnp.array([1, 1, 1])
a = a.at[k1, k2].set(v)
print(a)