diff --git a/src/tensorneat/common/state.py b/src/tensorneat/common/state.py index ba0a09c..6bc445b 100644 --- a/src/tensorneat/common/state.py +++ b/src/tensorneat/common/state.py @@ -23,6 +23,13 @@ class State: raise ValueError(f"Key {key} does not exist in state") return State(**{**self.state_dict, **kwargs}) + def remove(self, *keys): + for key in keys: + if key not in self.registered_keys(): + raise ValueError(f"Key {key} does not exist in state") + return State(**{k: v for k, v in self.state_dict.items() if k not in keys}) + + def __getattr__(self, name): return self.state_dict[name]