complete normal neat algorithm

This commit is contained in:
wls2002
2023-07-18 23:55:36 +08:00
parent 40cf0b6fbe
commit 0a2a9fd1be
26 changed files with 880 additions and 251 deletions

View File

@@ -20,12 +20,10 @@ class State:
return f"State ({self.state_dict})"
def tree_flatten(self):
print('tree_flatten_cal')
children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys())
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
print('tree_unflatten_cal')
return cls(**dict(zip(aux_data, children)))