Files
tensorneat-mend/tensorneat/common/state.py
2024-07-10 11:24:11 +08:00

50 lines
1.4 KiB
Python

from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class State:
def __init__(self, **kwargs):
self.__dict__["state_dict"] = kwargs
def registered_keys(self):
return self.state_dict.keys()
def register(self, **kwargs):
for key in kwargs:
if key in self.registered_keys():
raise ValueError(f"Key {key} already exists in state")
return State(**{**self.state_dict, **kwargs})
def update(self, **kwargs):
for key in kwargs:
if key not in self.registered_keys():
raise ValueError(f"Key {key} does not exist in state")
return State(**{**self.state_dict, **kwargs})
def __getattr__(self, name):
return self.state_dict[name]
def __setattr__(self, name, value):
raise AttributeError("State is immutable")
def __repr__(self):
return f"State ({self.state_dict})"
def __getstate__(self):
return self.state_dict.copy()
def __setstate__(self, state):
self.__dict__["state_dict"] = state
def __contains__(self, item):
return item in self.state_dict
def tree_flatten(self):
children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys())
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(**dict(zip(aux_data, children)))