use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
55 lines
858 B
Python
55 lines
858 B
Python
import numpy as np
|
|
|
|
import jax.numpy as jnp
|
|
import jax
|
|
|
|
a = {1:2, 2:3, 4:5}
|
|
print(a.values())
|
|
|
|
a = jnp.array([1, 0, 1, 0, np.nan])
|
|
b = jnp.array([1, 1, 1, 1, 1])
|
|
c = jnp.array([1, 1, 1, 1, 1])
|
|
|
|
full = jnp.array([
|
|
[1, 1, 1],
|
|
[0, 1, 1],
|
|
[1, 1, 1],
|
|
[0, 1, 1],
|
|
])
|
|
|
|
print(jnp.column_stack([a[:, None], b[:, None], c[:, None]]))
|
|
|
|
aux0 = full[:, 0, None]
|
|
aux1 = full[:, 1, None]
|
|
|
|
print(aux0, aux0.shape)
|
|
|
|
print(jnp.concatenate([aux0, aux1], axis=1))
|
|
|
|
f_a = jnp.array([False, False, True, True])
|
|
f_b = jnp.array([True, False, False, False])
|
|
|
|
print(jnp.logical_and(f_a, f_b))
|
|
print(f_a & f_b)
|
|
|
|
print(f_a + jnp.nan * 0.0)
|
|
print(f_a + 1 * 0.0)
|
|
|
|
|
|
@jax.jit
|
|
def main():
|
|
return func('happy') + func('sad')
|
|
|
|
|
|
def func(x):
|
|
if x == 'happy':
|
|
return 1
|
|
else:
|
|
return 2
|
|
|
|
a = jnp.zeros((3, 3))
|
|
print(a.dtype)
|
|
|
|
c = None
|
|
b = 1 or c
|
|
print(b) |