change a lot

This commit is contained in:
wls2002
2023-07-17 17:39:12 +08:00
parent a0a1ef6c58
commit f4763ebcea
21 changed files with 1060 additions and 4 deletions

View File

@@ -1,4 +1,4 @@
from jax.tree_util import register_pytree_node_class, tree_map
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
@@ -20,10 +20,12 @@ class State:
return f"State ({self.state_dict})"
def tree_flatten(self):
print('tree_flatten_cal')
children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys())
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
print('tree_unflatten_cal')
return cls(**dict(zip(aux_data, children)))