add save and load function for classes.

This commit is contained in:
wls2002
2024-06-09 20:33:02 +08:00
parent 374c05f5b7
commit 52e5d603f5
13 changed files with 70 additions and 42 deletions

View File

@@ -1,12 +1,7 @@
from utils import State from utils import State, StatefulBaseClass
class BaseAlgorithm: class BaseAlgorithm(StatefulBaseClass):
def setup(self, state=State()):
"""initialize the state of the algorithm"""
raise NotImplementedError
def ask(self, state: State): def ask(self, state: State):
"""require the population to be evaluated""" """require the population to be evaluated"""
raise NotImplementedError raise NotImplementedError

View File

@@ -1,9 +1,7 @@
from utils import State from utils import StatefulBaseClass
class BaseSubstrate: class BaseSubstrate(StatefulBaseClass):
def setup(self, state=State()):
return state
def make_nodes(self, query_res): def make_nodes(self, query_res):
raise NotImplementedError raise NotImplementedError

View File

@@ -1,9 +1,6 @@
from utils import State from utils import StatefulBaseClass
class BaseCrossover: class BaseCrossover(StatefulBaseClass):
def setup(self, state=State()):
return state
def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2): def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError raise NotImplementedError

View File

@@ -1,9 +1,6 @@
from utils import State from utils import StatefulBaseClass
class BaseMutation: class BaseMutation(StatefulBaseClass):
def setup(self, state=State()):
return state
def __call__(self, state, randkey, genome, nodes, conns, new_node_key): def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
raise NotImplementedError raise NotImplementedError

View File

@@ -1,8 +1,8 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State from utils import State, StatefulBaseClass
class BaseGene: class BaseGene(StatefulBaseClass):
"Base class for node genes or connection genes." "Base class for node genes or connection genes."
fixed_attrs = [] fixed_attrs = []
custom_attrs = [] custom_attrs = []
@@ -10,9 +10,6 @@ class BaseGene:
def __init__(self): def __init__(self):
pass pass
def setup(self, state=State()):
return state
def new_identity_attrs(self, state): def new_identity_attrs(self, state):
# the attrs which do identity transformation, used in mutate add node # the attrs which do identity transformation, used in mutate add node
raise NotImplementedError raise NotImplementedError

View File

@@ -1,4 +1,4 @@
import jax, jax.numpy as jnp import jax
from .. import BaseGene from .. import BaseGene

View File

@@ -1,10 +1,10 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover from ..ga import BaseMutation, BaseCrossover
from utils import State from utils import State, StatefulBaseClass
class BaseGenome: class BaseGenome(StatefulBaseClass):
network_type = None network_type = None
def __init__( def __init__(

View File

@@ -1,15 +1,12 @@
from utils import State from utils import State, StatefulBaseClass
from ..genome import BaseGenome from ..genome import BaseGenome
class BaseSpecies: class BaseSpecies(StatefulBaseClass):
genome: BaseGenome genome: BaseGenome
pop_size: int pop_size: int
species_size: int species_size: int
def setup(self, state=State()):
return state
def ask(self, state: State): def ask(self, state: State):
raise NotImplementedError raise NotImplementedError

View File

@@ -6,10 +6,10 @@ from algorithm import BaseAlgorithm
from problem import BaseProblem from problem import BaseProblem
from problem.rl_env import RLEnv from problem.rl_env import RLEnv
from problem.func_fit import FuncFit from problem.func_fit import FuncFit
from utils import State from utils import State, StatefulBaseClass
class Pipeline: class Pipeline(StatefulBaseClass):
def __init__( def __init__(
self, self,
algorithm: BaseAlgorithm, algorithm: BaseAlgorithm,

View File

@@ -1,15 +1,11 @@
from typing import Callable from typing import Callable
from utils import State from utils import State, StatefulBaseClass
class BaseProblem: class BaseProblem(StatefulBaseClass):
jitable = None jitable = None
def setup(self, state: State = State()):
"""initialize the state of the problem"""
return state
def evaluate(self, state: State, randkey, act_func: Callable, params): def evaluate(self, state: State, randkey, act_func: Callable, params):
"""evaluate one individual""" """evaluate one individual"""
raise NotImplementedError raise NotImplementedError

View File

@@ -3,3 +3,4 @@ from .aggregation import Agg, agg_func, AGG_ALL
from .tools import * from .tools import *
from .graph import * from .graph import *
from .state import State from .state import State
from .stateful_class import StatefulBaseClass

View File

@@ -30,6 +30,12 @@ class State:
def __repr__(self): def __repr__(self):
return f"State ({self.state_dict})" return f"State ({self.state_dict})"
def __getstate__(self):
return self.state_dict.copy()
def __setstate__(self, state):
self.__dict__["state_dict"] = state
def tree_flatten(self): def tree_flatten(self):
children = list(self.state_dict.values()) children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys()) aux_data = list(self.state_dict.keys())

View File

@@ -0,0 +1,44 @@
from typing import Optional
from . import State
import pickle
import datetime
import warnings
class StatefulBaseClass:
def setup(self, state=State()):
return state
def save(self, state: Optional[State] = None, path: Optional[str] = None):
if path is None:
time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
path = f"./{self.__class__.__name__} {time}.pkl"
if state is not None:
self.__dict__["aux_for_state"] = state
with open(path, "wb") as f:
pickle.dump(self, f)
@classmethod
def load(cls, path: str, with_state: bool = False, warning: bool = True):
with open(path, "rb") as f:
obj = pickle.load(f)
if with_state:
if "aux_for_state" not in obj.__dict__:
if warning:
warnings.warn(
"This object does not have state to load, return empty state",
category=UserWarning,
)
return obj, State()
state = obj.__dict__["aux_for_state"]
del obj.__dict__["aux_for_state"]
return obj, state
else:
if "aux_for_state" in obj.__dict__:
if warning:
warnings.warn(
"This object state to load, ignore it",
category=UserWarning,
)
del obj.__dict__["aux_for_state"]
return obj