optimize import

This commit is contained in:
wls2002
2023-06-29 09:41:49 +08:00
parent d28cef1a87
commit 01b7731231
14 changed files with 29 additions and 58 deletions

View File

@@ -2,5 +2,6 @@ from .mutate import mutate
from .distance import distance
from .crossover import crossover
from .graph import topological_sort, check_cycles
from .utils import unflatten_connections
from .utils import unflatten_connections, I_INT, fetch_first, rank_elements
from .forward import create_forward_function
from .genome import initialize_genomes

View File

@@ -1,105 +1,85 @@
import jax
import jax.numpy as jnp
from jax import jit
@jit
def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@jit
def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@jit
def sin_act(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@jit
def gauss_act(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@jit
def relu_act(z):
return jnp.maximum(z, 0)
@jit
def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@jit
def lelu_act(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@jit
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@jit
def softplus_act(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@jit
def identity_act(z):
return z
@jit
def clamped_act(z):
return jnp.clip(z, -1, 1)
@jit
def inv_act(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@jit
def log_act(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@jit
def exp_act(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@jit
def abs_act(z):
return jnp.abs(z)
@jit
def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@jit
def square_act(z):
return z ** 2
@jit
def cube_act(z):
return z ** 3

View File

@@ -1,7 +1,6 @@
import jax.numpy as jnp
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)

View File

@@ -6,8 +6,7 @@ See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGen
from typing import Tuple
import jax
from jax import jit, Array
from jax import numpy as jnp
from jax import jit, Array, numpy as jnp
@jit

View File

@@ -5,8 +5,7 @@ See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
from jax import jit, vmap, Array, numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON

View File

@@ -1,6 +1,5 @@
import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from jax import Array, numpy as jnp, jit, vmap
from .utils import I_INT

View File

@@ -4,10 +4,8 @@ Only used in feed-forward networks.
"""
import jax
from jax import jit, Array
from jax import numpy as jnp
from jax import jit, Array, numpy as jnp
# from .configs import fetch_first, I_INT
from algorithms.neat.genome.utils import fetch_first, I_INT

View File

@@ -4,11 +4,9 @@ The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
from jax import numpy as jnp
from jax import jit, Array
from jax import numpy as jnp, jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection

View File

@@ -2,8 +2,7 @@ from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
@@ -60,6 +59,7 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
@@ -68,4 +68,4 @@ def rank_elements(array, reverse=False):
"""
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))
return jnp.argsort(jnp.argsort(array))