# Snippet metadata (from the page this was copied from): 12 lines, 178 B, Python.
"""Demonstrate a scatter (index) update into a JAX array.

JAX arrays are immutable, so in-place assignment such as ``a[i, j] = x`` is
not available. ``a.at[idx].set(values)`` returns a *new* array with the update
applied, and the script rebinds that result to ``a``.
"""

import numpy as np  # noqa: F401  (not used by this script; import kept as-is)
import jax.numpy as jnp

# 5x5 matrix of zeros (JAX's default floating dtype).
a = jnp.zeros((5, 5))

# Paired row/column indices: the update touches (1, 2), (2, 3) and (3, 4).
k1 = jnp.array([1, 2, 3])
k2 = jnp.array([2, 3, 4])

# One value per (row, col) pair; .set() casts these ints to a's dtype.
v = jnp.array([1, 1, 1])

# Functional update: .at[...].set(...) returns a new array rather than
# mutating `a`, so the result must be reassigned.
a = a.at[k1, k2].set(v)

print(a)