use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -3,7 +3,6 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Act:
|
||||
|
||||
@staticmethod
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(5 * z, -10, 10)
|
||||
@@ -36,11 +35,7 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def inv(z):
|
||||
z = jnp.where(
|
||||
z > 0,
|
||||
jnp.maximum(z, 1e-7),
|
||||
jnp.minimum(z, -1e-7)
|
||||
)
|
||||
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
|
||||
return 1 / z
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user