# Source listing metadata: 48 lines, 767 B, Python
import numpy as np
import jax.numpy as jnp
import jax

# Scratch experiments with JAX array shapes, dtype promotion and jit.

# np.nan forces a floating dtype for `a`; `b` and `c` are integer arrays.
a = jnp.array([1, 0, 1, 0, np.nan])
b = jnp.array([1, 1, 1, 1, 1])
c = jnp.array([1, 1, 1, 1, 1])

# 4x3 integer matrix used for the column-slicing experiments below.
full = jnp.array([
    [1, 1, 1],
    [0, 1, 1],
    [1, 1, 1],
    [0, 1, 1],
])

# Each x[:, None] is a (5, 1) column; column_stack joins them into a (5, 3)
# array, promoting the integer columns to a's floating dtype.
print(jnp.column_stack([a[:, None], b[:, None], c[:, None]]))

# A trailing None index keeps the selected column 2-D: shape (4, 1).
aux0 = full[:, 0, None]
aux1 = full[:, 1, None]

print(aux0, aux0.shape)

# Joining two (4, 1) columns along axis 1 gives a (4, 2) array.
print(jnp.concatenate([aux0, aux1], axis=1))

f_a = jnp.array([False, False, True, True])
f_b = jnp.array([True, False, False, False])

# On boolean arrays, jnp.logical_and and `&` give the same element-wise result.
print(jnp.logical_and(f_a, f_b))
print(f_a & f_b)

# Adding a float scalar promotes the bool array to float (True -> 1.0,
# False -> 0.0). nan * 0.0 is still nan, so the first print is all nan.
print(f_a + jnp.nan * 0.0)
print(f_a + 1 * 0.0)


@jax.jit
def main():
    """Return ``func('happy') + func('sad')`` as a jit-compiled computation.

    The arguments passed to ``func`` are plain Python strings, not traced
    values, so ``func``'s ``if`` is evaluated while JAX traces ``main``.
    With ``func`` returning 1 for 'happy' and 2 otherwise, the result is 3,
    which the jitted function returns as a JAX array.
    """
    return func('happy') + func('sad')


def func(x):
    """Map the label 'happy' to 1 and every other value to 2.

    Args:
        x: Value compared (with ``==``) against the string 'happy'.

    Returns:
        1 if ``x == 'happy'``, otherwise 2. Any non-'happy' input,
        including 'sad', falls through to 2.
    """
    return 1 if x == 'happy' else 2


# Run the jitted main (traced and compiled on this first call) and print it.
print(main())