update class State. Add method register and update method update.

This commit is contained in:
wls2002
2024-05-25 16:05:47 +08:00
parent 25f66dc2fb
commit 3b2f917aee
2 changed files with 29 additions and 14 deletions

View File

@@ -7,7 +7,19 @@ class State:
def __init__(self, **kwargs):
self.__dict__['state_dict'] = kwargs
def registered_keys(self):
return self.state_dict.keys()
def register(self, **kwargs):
for key in kwargs:
if key in self.registered_keys():
raise ValueError(f"Key {key} already exists in state")
return State(**{**self.state_dict, **kwargs})
def update(self, **kwargs):
for key in kwargs:
if key not in self.registered_keys():
raise ValueError(f"Key {key} does not exist in state")
return State(**{**self.state_dict, **kwargs})
def __getattr__(self, name):
@@ -26,4 +38,4 @@ class State:
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(**dict(zip(aux_data, children)))
return cls(**dict(zip(aux_data, children)))