From b30cbdc6696d5f33b342362a8d113692b63acce3 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 25 Feb 2025 10:50:50 +0800 Subject: [PATCH] add remove in state --- src/tensorneat/common/state.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/tensorneat/common/state.py b/src/tensorneat/common/state.py index ba0a09c..6bc445b 100644 --- a/src/tensorneat/common/state.py +++ b/src/tensorneat/common/state.py @@ -23,6 +23,13 @@ class State: raise ValueError(f"Key {key} does not exist in state") return State(**{**self.state_dict, **kwargs}) + def remove(self, *keys): + for key in keys: + if key not in self.registered_keys(): + raise ValueError(f"Key {key} does not exist in state") + return State(**{k: v for k, v in self.state_dict.items() if k not in keys}) + + def __getattr__(self, name): return self.state_dict[name]