Files
tensorneat-mend/examples/jax_playground.py
2023-05-06 11:35:44 +08:00

14 lines
231 B
Python

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax import vmap, jit
# Playground: see how JAX's functional PRNG keys are derived and shaped.
# `random` here is `jax.random` (imported above), so this is the same
# PRNGKey constructor as `jax.random.PRNGKey`.
seed = random.PRNGKey(42)

# Derive three child keys from the root. The first child replaces `seed`, so
# the parent key is never consumed twice. The other two are kept as a Python
# list of individual keys, which is what star-unpacking the key array yields.
_children = random.split(seed, 3)
seed = _children[0]
subkeys = list(_children[1:])
del _children  # don't leave a helper name at module level

# Splitting into a single key still returns a stacked key array (leading
# axis of length 1), not a bare key -- printed below to show that shape.
c = random.split(seed, 1)

print(seed, subkeys)
print(c)