"""Create and print a zero-filled JAX array of shape (0, 9, 9).

The leading dimension is 0, so the array contains no elements even though
its trailing dimensions are 9 x 9.
"""
import jax.numpy as jnp

# Each statement is on its own line: the original single-line form (three
# statements with no ';' separators) was a SyntaxError and never ran.
a = jnp.zeros((0, 9, 9))
print(a)