use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -3,7 +3,6 @@ import jax.numpy as jnp
class Act:
@staticmethod
def sigmoid(z):
z = jnp.clip(5 * z, -10, 10)
@@ -36,11 +35,7 @@ class Act:
@staticmethod
def inv(z):
z = jnp.where(
z > 0,
jnp.maximum(z, 1e-7),
jnp.minimum(z, -1e-7)
)
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
return 1 / z
@staticmethod