create state

This commit is contained in:
wls2002
2023-07-14 17:27:22 +08:00
parent 7265e33c43
commit a0a1ef6c58
41 changed files with 43 additions and 2882 deletions

14
examples/state_test.py Normal file
View File

@@ -0,0 +1,14 @@
import jax
from algorithm.state import State
@jax.jit
def func(state: State, a):
return state.update(a=a)
state = State(c=1, b=2)
print(state)
state = func(state, 1111111)
print(state)