update some files for save

This commit is contained in:
root
2024-07-15 14:25:51 +08:00
parent f032564a43
commit 6edf083d4f
12 changed files with 110 additions and 111 deletions

View File

@@ -1,3 +1,5 @@
import pickle
from jax.tree_util import register_pytree_node_class
@@ -39,6 +41,15 @@ class State:
def __contains__(self, item):
return item in self.state_dict
def save(self, file_name):
with open(file_name, "wb") as f:
pickle.dump(self, f)
@classmethod
def load(cls, file_name):
with open(file_name, "rb") as f:
return pickle.load(f)
def tree_flatten(self):
children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys())

View File

@@ -9,30 +9,6 @@ class StatefulBaseClass:
def setup(self, state=State()):
return state
def save(self, state: Optional[State] = None, path: Optional[str] = None):
if path is None:
time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
path = f"./{self.__class__.__name__} {time}.pkl"
if state is not None:
self.__dict__["aux_for_state"] = state
with open(path, "wb") as f:
pickle.dump(self, f)
def __getstate__(self):
# only pickle the picklable attributes
state = self.__dict__.copy()
non_picklable_keys = []
for key, value in state.items():
try:
pickle.dumps(value)
except Exception:
non_picklable_keys.append(key)
for key in non_picklable_keys:
state.pop(key)
return state
def show_config(self, registered_objects=None):
if registered_objects is None: # root call
registered_objects = []
@@ -47,27 +23,53 @@ class StatefulBaseClass:
config[str(key)] = str(value)
return config
@classmethod
def load(cls, path: str, with_state: bool = False, warning: bool = True):
with open(path, "rb") as f:
obj = pickle.load(f)
if with_state:
if "aux_for_state" not in obj.__dict__:
if warning:
warnings.warn(
"This object does not have state to load, return empty state",
category=UserWarning,
)
return obj, State()
state = obj.__dict__["aux_for_state"]
del obj.__dict__["aux_for_state"]
return obj, state
else:
if "aux_for_state" in obj.__dict__:
if warning:
warnings.warn(
"This object has state to load, ignore it",
category=UserWarning,
)
del obj.__dict__["aux_for_state"]
return obj
# TODO: Bug need be fixed
# def save(self, state: Optional[State] = None, path: Optional[str] = None):
# if path is None:
# time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
# path = f"./{self.__class__.__name__} {time}.pkl"
# if state is not None:
# self.__dict__["aux_for_state"] = state
# with open(path, "wb") as f:
# pickle.dump(self, f)
# def __getstate__(self):
# # only pickle the picklable attributes
# state = self.__dict__.copy()
# non_picklable_keys = []
# for key, value in state.items():
# try:
# pickle.dumps(value)
# except Exception as e:
# print(f"Cannot pickle key {key}: {e}")
# non_picklable_keys.append(key)
# for key in non_picklable_keys:
# state.pop(key)
# return state
# @classmethod
# def load(cls, path: str, with_state: bool = False, warning: bool = True):
# with open(path, "rb") as f:
# obj = pickle.load(f)
# if with_state:
# if "aux_for_state" not in obj.__dict__:
# if warning:
# warnings.warn(
# "This object does not have state to load, return empty state",
# category=UserWarning,
# )
# return obj, State()
# state = obj.__dict__["aux_for_state"]
# del obj.__dict__["aux_for_state"]
# return obj, state
# else:
# if "aux_for_state" in obj.__dict__:
# if warning:
# warnings.warn(
# "This object has state to load, ignore it",
# category=UserWarning,
# )
# del obj.__dict__["aux_for_state"]
# return obj