Files
tensorneat-mend/examples/state_test.py
2023-07-14 17:27:22 +08:00

15 lines
194 B
Python

import jax
from algorithm.state import State
@jax.jit
def func(state: State, a):
    """Forward ``a`` to ``state.update`` under ``jax.jit`` and return the result.

    Because the function is jit-compiled, ``state`` has to be something JAX
    accepts as a jit argument (presumably ``State`` is registered as a
    pytree -- confirm in ``algorithm.state``).

    Args:
        state: The ``State`` object whose ``update`` method is called.
        a: Value passed to ``state.update`` as the keyword argument ``a``.

    Returns:
        Whatever ``state.update(a=a)`` returns (presumably a new ``State``;
        verify against ``algorithm.state``).
    """
    return state.update(a=a)
# Construct a State from the keywords c=1 and b=2 and show it as built.
state = State(c=1, b=2)
print(state)

# Run the state through the jitted helper, supplying a value for `a`,
# rebind `state` to the result, and show it again for comparison.
state = func(state, 1_111_111)
print(state)