diff --git a/tensorneat/algorithm/base.py b/tensorneat/algorithm/base.py index 2c9ee50..7493be7 100644 --- a/tensorneat/algorithm/base.py +++ b/tensorneat/algorithm/base.py @@ -1,12 +1,7 @@ -from utils import State +from utils import State, StatefulBaseClass -class BaseAlgorithm: - def setup(self, state=State()): - """initialize the state of the algorithm""" - - raise NotImplementedError - +class BaseAlgorithm(StatefulBaseClass): def ask(self, state: State): """require the population to be evaluated""" raise NotImplementedError diff --git a/tensorneat/algorithm/hyperneat/substrate/base.py b/tensorneat/algorithm/hyperneat/substrate/base.py index 6172c8b..4a00925 100644 --- a/tensorneat/algorithm/hyperneat/substrate/base.py +++ b/tensorneat/algorithm/hyperneat/substrate/base.py @@ -1,9 +1,7 @@ -from utils import State +from utils import StatefulBaseClass -class BaseSubstrate: - def setup(self, state=State()): - return state +class BaseSubstrate(StatefulBaseClass): def make_nodes(self, query_res): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index b59ce6c..d34921a 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -1,9 +1,6 @@ -from utils import State +from utils import StatefulBaseClass -class BaseCrossover: - def setup(self, state=State()): - return state - +class BaseCrossover(StatefulBaseClass): def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/mutation/base.py b/tensorneat/algorithm/neat/ga/mutation/base.py index 68bd05a..fa5cf73 100644 --- a/tensorneat/algorithm/neat/ga/mutation/base.py +++ b/tensorneat/algorithm/neat/ga/mutation/base.py @@ -1,9 +1,6 @@ -from utils import State +from utils import StatefulBaseClass -class BaseMutation: - def setup(self, state=State()): - return state - +class BaseMutation(StatefulBaseClass): def __call__(self, state, randkey, genome, nodes, conns, new_node_key): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index 8ef4f71..d3cbad6 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -1,8 +1,8 @@ import jax, jax.numpy as jnp -from utils import State +from utils import State, StatefulBaseClass -class BaseGene: +class BaseGene(StatefulBaseClass): "Base class for node genes or connection genes." fixed_attrs = [] custom_attrs = [] @@ -10,9 +10,6 @@ class BaseGene: def __init__(self): pass - def setup(self, state=State()): - return state - def new_identity_attrs(self, state): # the attrs which do identity transformation, used in mutate add node raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py index 66de56e..39d72fe 100644 --- a/tensorneat/algorithm/neat/gene/conn/base.py +++ b/tensorneat/algorithm/neat/gene/conn/base.py @@ -1,4 +1,4 @@ -import jax, jax.numpy as jnp +import jax from .. import BaseGene diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index e97d828..71170df 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -1,10 +1,10 @@ import jax, jax.numpy as jnp from ..gene import BaseNodeGene, BaseConnGene from ..ga import BaseMutation, BaseCrossover -from utils import State +from utils import State, StatefulBaseClass -class BaseGenome: +class BaseGenome(StatefulBaseClass): network_type = None def __init__( diff --git a/tensorneat/algorithm/neat/species/base.py b/tensorneat/algorithm/neat/species/base.py index f2b9cc1..4654dba 100644 --- a/tensorneat/algorithm/neat/species/base.py +++ b/tensorneat/algorithm/neat/species/base.py @@ -1,15 +1,12 @@ -from utils import State +from utils import State, StatefulBaseClass from ..genome import BaseGenome -class BaseSpecies: +class BaseSpecies(StatefulBaseClass): genome: BaseGenome pop_size: int species_size: int - def setup(self, state=State()): - return state - def ask(self, state: State): raise NotImplementedError diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index a31808b..798d424 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -6,10 +6,10 @@ from algorithm import BaseAlgorithm from problem import BaseProblem from problem.rl_env import RLEnv from problem.func_fit import FuncFit -from utils import State +from utils import State, StatefulBaseClass -class Pipeline: +class Pipeline(StatefulBaseClass): def __init__( self, algorithm: BaseAlgorithm, diff --git a/tensorneat/problem/base.py b/tensorneat/problem/base.py index 712c989..bf78f78 100644 --- a/tensorneat/problem/base.py +++ b/tensorneat/problem/base.py @@ -1,15 +1,11 @@ from typing import Callable -from utils import State +from utils import State, StatefulBaseClass -class BaseProblem: +class BaseProblem(StatefulBaseClass): jitable = None - def setup(self, state: State = State()): - """initialize the state of the problem""" - return state - def evaluate(self, state: State, randkey, act_func: Callable, params): """evaluate one individual""" raise NotImplementedError diff --git a/tensorneat/utils/__init__.py b/tensorneat/utils/__init__.py index c46db75..f826e4b 100644 --- a/tensorneat/utils/__init__.py +++ b/tensorneat/utils/__init__.py @@ -3,3 +3,4 @@ from .aggregation import Agg, agg_func, AGG_ALL from .tools import * from .graph import * from .state import State +from .stateful_class import StatefulBaseClass diff --git a/tensorneat/utils/state.py b/tensorneat/utils/state.py index 84a4fe6..bac0a11 100644 --- a/tensorneat/utils/state.py +++ b/tensorneat/utils/state.py @@ -30,6 +30,12 @@ class State: def __repr__(self): return f"State ({self.state_dict})" + def __getstate__(self): + return self.state_dict.copy() + + def __setstate__(self, state): + self.__dict__["state_dict"] = state + def tree_flatten(self): children = list(self.state_dict.values()) aux_data = list(self.state_dict.keys()) diff --git a/tensorneat/utils/stateful_class.py b/tensorneat/utils/stateful_class.py new file mode 100644 index 0000000..e865531 --- /dev/null +++ b/tensorneat/utils/stateful_class.py @@ -0,0 +1,44 @@ +from typing import Optional +from . import State +import pickle +import datetime +import warnings + + +class StatefulBaseClass: + def setup(self, state=State()): + return state + + def save(self, state: Optional[State] = None, path: Optional[str] = None): + if path is None: + time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + path = f"./{self.__class__.__name__} {time}.pkl" + if state is not None: + self.__dict__["aux_for_state"] = state + with open(path, "wb") as f: + pickle.dump(self, f) + + @classmethod + def load(cls, path: str, with_state: bool = False, warning: bool = True): + with open(path, "rb") as f: + obj = pickle.load(f) + if with_state: + if "aux_for_state" not in obj.__dict__: + if warning: + warnings.warn( + "This object does not have state to load, return empty state", + category=UserWarning, + ) + return obj, State() + state = obj.__dict__["aux_for_state"] + del obj.__dict__["aux_for_state"] + return obj, state + else: + if "aux_for_state" in obj.__dict__: + if warning: + warnings.warn( + "This object state to load, ignore it", + category=UserWarning, + ) + del obj.__dict__["aux_for_state"] + return obj