change a lot
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from jax.tree_util import register_pytree_node_class, tree_map
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
@register_pytree_node_class
|
||||
@@ -20,10 +20,12 @@ class State:
|
||||
return f"State ({self.state_dict})"
|
||||
|
||||
def tree_flatten(self):
|
||||
print('tree_flatten_cal')
|
||||
children = list(self.state_dict.values())
|
||||
aux_data = list(self.state_dict.keys())
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
print('tree_unflatten_cal')
|
||||
return cls(**dict(zip(aux_data, children)))
|
||||
|
||||
Reference in New Issue
Block a user