adjust default parameter; successful run recurrent-xor example

This commit is contained in:
root
2024-07-11 10:57:43 +08:00
parent 4a631f9464
commit 9bad577d89
18 changed files with 118 additions and 136 deletions

View File

@@ -9,7 +9,7 @@ from .activation.act_jnp import Act, ACT_ALL, act_func
from .aggregation.agg_sympy import *
from .activation.act_sympy import *
from typing import Union
from typing import Callable, Union
name2sympy = {
"sigmoid": SympySigmoid,
@@ -34,7 +34,7 @@ name2sympy = {
}
def convert_to_sympy(func: Union[str, callable]):
def convert_to_sympy(func: Union[str, Callable]):
if isinstance(func, str):
name = func
else:

View File

@@ -31,7 +31,7 @@ class Act:
@staticmethod
def standard_tanh(z):
z =5 * z / sigma_3
z = 5 * z / sigma_3
return jnp.tanh(z) # (-1, 1)
@staticmethod
@@ -52,7 +52,6 @@ class Act:
@staticmethod
def identity(z):
z = jnp.clip(z, -sigma_3, sigma_3)
return z
@staticmethod

View File

@@ -54,13 +54,6 @@ class SympyStandardSigmoid(sp.Function):
def eval(cls, z):
return SympySigmoid_(5 * z / sigma_3)
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# z = 1 / (1 + backend.exp(-z))
#
# return z # (0, 1)
class SympyTanh(sp.Function):
@classmethod
@@ -68,11 +61,6 @@ class SympyTanh(sp.Function):
z = 5 * z / sigma_3
return sp.tanh(z) * sigma_3
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyStandardTanh(sp.Function):
@classmethod
@@ -80,11 +68,6 @@ class SympyStandardTanh(sp.Function):
z = 5 * z / sigma_3
return sp.tanh(z)
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# return backend.tanh(z) # (-1, 1)
class SympySin(sp.Function):
@classmethod
@@ -143,14 +126,7 @@ class SympyLelu(sp.Function):
class SympyIdentity(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(z, -sigma_3, sigma_3)
return z
return None
@staticmethod
def numerical_eval(z, backend=np):
return backend.clip(z, -sigma_3, sigma_3)
return z
class SympyInv(sp.Function):