finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -3,6 +3,9 @@ import numpy as np
import jax.numpy as jnp
import jax
a = {1:2, 2:3, 4:5}
print(a.values())
a = jnp.array([1, 0, 1, 0, np.nan])
b = jnp.array([1, 1, 1, 1, 1])
c = jnp.array([1, 1, 1, 1, 1])
@@ -44,5 +47,9 @@ def func(x):
else:
return 2
a = jnp.zeros((3, 3))
print(a.dtype)
print(main())
c = None
b = 1 or c
print(b)