tensorneat-mend/examples/a.py
wls2002 0cb2f9473d finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
2023-06-25 00:26:52 +08:00

55 lines
858 B
Python

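# Quick checks of assorted JAX / NumPy behaviors:
# dict views, column stacking, NaN propagation, and jit with Python-level control flow.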
import numpy as np
import jax.numpy as jnp
import jax
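
# dict.values() returns a dict_values view, not a list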
a = {1:2, 2:3, 4:5}
print(a.values())
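
# three length-5 vectors (one containing a NaN) for the column_stack test below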
a = jnp.array([1, 0, 1, 0, np.nan])
b = jnp.array([1, 1, 1, 1, 1])
c = jnp.array([1, 1, 1, 1, 1])
full = jnp.array([
    [1, 1, 1],
    [0, 1, 1],
    [1, 1, 1],
    [0, 1, 1],
])
print(jnp.column_stack([a[:, None], b[:, None], c[:, None]]))
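
# full[:, i, None] keeps each column as a (4, 1) array, so concatenate along axis=1 gives (4, 2)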
aux0 = full[:, 0, None]
aux1 = full[:, 1, None]
print(aux0, aux0.shape)
print(jnp.concatenate([aux0, aux1], axis=1))
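
# logical_and and & agree elementwise on bool arrays;
# adding NaN promotes the bool array to float and fills it with NaN,
# while adding 0.0 just promotes it to 0.0 / 1.0 floats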
f_a = jnp.array([False, False, True, True])
f_b = jnp.array([True, False, False, False])
print(jnp.logical_and(f_a, f_b))
print(f_a & f_b)
print(f_a + jnp.nan * 0.0)
print(f_a + 1 * 0.0)
# jit traces main lazily, so func can be defined below it; the string comparison
# in func is resolved at trace time (plain Python), not inside the compiled graph
@jax.jit
def main():
    return func('happy') + func('sad')

def func(x):
    if x == 'happy':
        return 1
    else:
        return 2

print(main())  # 1 + 2 -> 3
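
# jnp.zeros defaults to float32 (float64 only if x64 mode is enabled)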
a = jnp.zeros((3, 3))
print(a.dtype)
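
# `or` short-circuits: 1 is truthy, so b is 1 and c (None) is never evaluated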
c = None
b = 1 or c
print(b)
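
# Illustrative sketch, not part of the original file: the commit message mentions
# jax.lax.while_loop; its signature is while_loop(cond_fun, body_fun, init_val).
# The helper name below is made up for the example.
def count_up_to_ten(x):
    return jax.lax.while_loop(lambda i: i < 10, lambda i: i + 1, x)

print(jax.jit(count_up_to_ten)(0))  # 10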