14 lines
231 B
Python
14 lines
231 B
Python
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from jax import random
|
|
from jax import vmap, jit
|
|
|
|
|
|
seed = jax.random.PRNGKey(42)
|
|
seed, *subkeys = random.split(seed, 3)
|
|
|
|
|
|
c = random.split(seed, 1)
|
|
print(seed, subkeys)
|
|
print(c) |