add save function in pipeline

This commit is contained in:
wls2002
2024-06-16 21:47:53 +08:00
parent b9d6482d11
commit fb2ae5d2fa
10 changed files with 94 additions and 164 deletions

View File

@@ -14,7 +14,8 @@ class Act:
@staticmethod
def tanh(z):
return jnp.tanh(0.6 * z)
z = jnp.clip(0.6*z, -3, 3)
return jnp.tanh(z)
@staticmethod
def sin(z):

View File

@@ -45,11 +45,15 @@ class SympySigmoid(sp.Function):
class SympyTanh(sp.Function):
@classmethod
def eval(cls, z):
return sp.tanh(0.6 * z)
if z.is_Number:
z = SympyClip(0.6 * z, -3, 3)
return sp.tanh(z)
return None
@staticmethod
def numerical_eval(z, backend=np):
return backend.tanh(0.6 * z)
z = backend.clip(0.6*z, -3, 3)
return backend.tanh(z)
class SympySin(sp.Function):

View File

@@ -1,3 +1,4 @@
import json
from typing import Optional
from . import State
import pickle
@@ -18,6 +19,15 @@ class StatefulBaseClass:
with open(path, "wb") as f:
pickle.dump(self, f)
def show_config(self):
config = {}
for key, value in self.__dict__.items():
if isinstance(value, StatefulBaseClass):
config[str(key)] = value.show_config()
else:
config[str(key)] = str(value)
return config
@classmethod
def load(cls, path: str, with_state: bool = False, warning: bool = True):
with open(path, "rb") as f: