import jax.numpy as jnp

# A 1x5 array in which every entry is NaN.
# NOTE(review): the name suggests this is a placeholder row for a missing
# node, with NaN marking "no value" -- confirm against the code that uses it.
EMPTY_NODE = jnp.full((1, 5), jnp.nan)

# Runs at import time as well as when executed as a script; kept because the
# original emits the array at module level.
print(EMPTY_NODE)