add save function in pipeline
This commit is contained in:
@@ -14,7 +14,8 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def tanh(z):
|
||||
return jnp.tanh(0.6 * z)
|
||||
z = jnp.clip(0.6*z, -3, 3)
|
||||
return jnp.tanh(z)
|
||||
|
||||
@staticmethod
|
||||
def sin(z):
|
||||
|
||||
@@ -45,11 +45,15 @@ class SympySigmoid(sp.Function):
|
||||
class SympyTanh(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.tanh(0.6 * z)
|
||||
if z.is_Number:
|
||||
z = SympyClip(0.6 * z, -3, 3)
|
||||
return sp.tanh(z)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
return backend.tanh(0.6 * z)
|
||||
z = backend.clip(0.6*z, -3, 3)
|
||||
return backend.tanh(z)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from . import State
|
||||
import pickle
|
||||
@@ -18,6 +19,15 @@ class StatefulBaseClass:
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
def show_config(self):
|
||||
config = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if isinstance(value, StatefulBaseClass):
|
||||
config[str(key)] = value.show_config()
|
||||
else:
|
||||
config[str(key)] = str(value)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str, with_state: bool = False, warning: bool = True):
|
||||
with open(path, "rb") as f:
|
||||
|
||||
Reference in New Issue
Block a user