complete normal neat algorithm
This commit is contained in:
@@ -20,12 +20,10 @@ class State:
|
||||
return f"State ({self.state_dict})"
|
||||
|
||||
def tree_flatten(self):
|
||||
print('tree_flatten_cal')
|
||||
children = list(self.state_dict.values())
|
||||
aux_data = list(self.state_dict.keys())
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
print('tree_unflatten_cal')
|
||||
return cls(**dict(zip(aux_data, children)))
|
||||
|
||||
Reference in New Issue
Block a user