modify act funcs and sympy act funcs;

add dense and advance initialize genome;
add input_transform for genome;
This commit is contained in:
wls2002
2024-06-18 16:01:11 +08:00
parent 907314bc80
commit ce8015d22c
7 changed files with 222 additions and 39 deletions

View File

@@ -20,7 +20,6 @@ name2sympy = {
"relu": SympyRelu,
"lelu": SympyLelu,
"identity": SympyIdentity,
"clamped": SympyClamped,
"inv": SympyInv,
"log": SympyLog,
"exp": SympyExp,

View File

@@ -2,6 +2,9 @@ import jax
import jax.numpy as jnp
sigma_3 = 2.576
class Act:
@staticmethod
def name2func(name):
@@ -9,35 +12,42 @@ class Act:
@staticmethod
def sigmoid(z):
z = jnp.clip(5 * z, -10, 10)
return 1 / (1 + jnp.exp(-z))
z = jnp.clip(5 * z / sigma_3, -5, 5)
z = 1 / (1 + jnp.exp(-z))
return z * sigma_3 # (0, sigma_3)
@staticmethod
def tanh(z):
z = jnp.clip(0.6*z, -3, 3)
return jnp.tanh(z)
z = jnp.clip(5 * z / sigma_3, -5, 5)
return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
@staticmethod
def standard_tanh(z):
z = jnp.clip(5 * z / sigma_3, -5, 5)
return jnp.tanh(z) # (-1, 1)
@staticmethod
def sin(z):
return jnp.sin(z)
z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2)
return jnp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
@staticmethod
def relu(z):
return jnp.maximum(z, 0)
z = jnp.clip(z, -sigma_3, sigma_3)
return jnp.maximum(z, 0) # (0, sigma_3)
@staticmethod
def lelu(z):
leaky = 0.005
z = jnp.clip(z, -sigma_3, sigma_3)
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def identity(z):
z = jnp.clip(z, -sigma_3, sigma_3)
return z
@staticmethod
def clamped(z):
return jnp.clip(z, -1, 1)
@staticmethod
def inv(z):
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
@@ -55,6 +65,7 @@ class Act:
@staticmethod
def abs(z):
z = jnp.clip(z, -1, 1)
return jnp.abs(z)
@@ -65,7 +76,6 @@ ACT_ALL = (
Act.relu,
Act.lelu,
Act.identity,
Act.clamped,
Act.inv,
Act.log,
Act.exp,

View File

@@ -2,6 +2,9 @@ import sympy as sp
import numpy as np
sigma_3 = 2.576
class SympyClip(sp.Function):
@classmethod
def eval(cls, val, min_val, max_val):
@@ -26,14 +29,17 @@ class SympySigmoid(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(5 * z, -10, 10)
return 1 / (1 + sp.exp(-z))
z = SympyClip(5 * z / sigma_3, -5, 5)
z = 1 / (1 + sp.exp(-z))
return z * sigma_3
return None
@staticmethod
def numerical_eval(z, backend=np):
z = backend.clip(5 * z, -10, 10)
return 1 / (1 + backend.exp(-z))
z = backend.clip(5 * z / sigma_3, -5, 5)
z = 1 / (1 + backend.exp(-z))
return z * sigma_3 # (0, sigma_3)
def _sympystr(self, printer):
return f"sigmoid({self.args[0]})"
@@ -46,36 +52,56 @@ class SympyTanh(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(0.6 * z, -3, 3)
z = SympyClip(5 * z / sigma_3, -5, 5)
return sp.tanh(z) * sigma_3
return None
@staticmethod
def numerical_eval(z, backend=np):
z = backend.clip(5 * z / sigma_3, -5, 5)
return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyStandardTanh(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(5 * z / sigma_3, -5, 5)
return sp.tanh(z)
return None
@staticmethod
def numerical_eval(z, backend=np):
z = backend.clip(0.6*z, -3, 3)
return backend.tanh(z)
z = backend.clip(5 * z / sigma_3, -5, 5)
return backend.tanh(z) # (-1, 1)
class SympySin(sp.Function):
@classmethod
def eval(cls, z):
return sp.sin(z)
if z.is_Number:
z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2)
return sp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
return None
@staticmethod
def numerical_eval(z, backend=np):
return backend.sin(z)
z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2)
return backend.sin(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyRelu(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
return sp.Piecewise((z, z > 0), (0, True))
z = SympyClip(z, -sigma_3, sigma_3)
return sp.Max(z, 0) # (0, sigma_3)
return None
@staticmethod
def numerical_eval(z, backend=np):
return backend.maximum(z, 0)
z = backend.clip(z, -sigma_3, sigma_3)
return backend.maximum(z, 0) # (0, sigma_3)
def _sympystr(self, printer):
return f"relu({self.args[0]})"
@@ -107,21 +133,14 @@ class SympyLelu(sp.Function):
class SympyIdentity(sp.Function):
@classmethod
def eval(cls, z):
return z
if z.is_Number:
z = SympyClip(z, -sigma_3, sigma_3)
return z
return None
@staticmethod
def numerical_eval(z, backend=np):
return z
class SympyClamped(sp.Function):
@classmethod
def eval(cls, z):
return SympyClip(z, -1, 1)
@staticmethod
def numerical_eval(z, backend=np):
return backend.clip(z, -1, 1)
return backend.clip(z, -sigma_3, sigma_3)
class SympyInv(sp.Function):