optimize import

This commit is contained in:
wls2002
2023-06-29 09:41:49 +08:00
parent d28cef1a87
commit 01b7731231
14 changed files with 29 additions and 58 deletions

View File

@@ -1,5 +1,8 @@
"""
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward_function, topological_sort, unflatten_connections
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate
from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func

View File

@@ -2,5 +2,6 @@ from .mutate import mutate
from .distance import distance
from .crossover import crossover
from .graph import topological_sort, check_cycles
from .utils import unflatten_connections
from .utils import unflatten_connections, I_INT, fetch_first, rank_elements
from .forward import create_forward_function
from .genome import initialize_genomes

View File

@@ -1,105 +1,85 @@
import jax
import jax.numpy as jnp
from jax import jit
@jit
def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@jit
def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@jit
def sin_act(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@jit
def gauss_act(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@jit
def relu_act(z):
return jnp.maximum(z, 0)
@jit
def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@jit
def lelu_act(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@jit
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@jit
def softplus_act(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@jit
def identity_act(z):
return z
@jit
def clamped_act(z):
return jnp.clip(z, -1, 1)
@jit
def inv_act(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@jit
def log_act(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@jit
def exp_act(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@jit
def abs_act(z):
return jnp.abs(z)
@jit
def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@jit
def square_act(z):
return z ** 2
@jit
def cube_act(z):
return z ** 3

View File

@@ -1,7 +1,6 @@
import jax.numpy as jnp
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)

View File

@@ -6,8 +6,7 @@ See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGen
from typing import Tuple
import jax
from jax import jit, Array
from jax import numpy as jnp
from jax import jit, Array, numpy as jnp
@jit

View File

@@ -5,8 +5,7 @@ See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
from jax import jit, vmap, Array, numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON

View File

@@ -1,6 +1,5 @@
import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from jax import Array, numpy as jnp, jit, vmap
from .utils import I_INT

View File

@@ -4,10 +4,8 @@ Only used in feed-forward networks.
"""
import jax
from jax import jit, Array
from jax import numpy as jnp
from jax import jit, Array, numpy as jnp
# from .configs import fetch_first, I_INT
from algorithms.neat.genome.utils import fetch_first, I_INT

View File

@@ -4,11 +4,9 @@ The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
from jax import numpy as jnp
from jax import jit, Array
from jax import numpy as jnp, jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection

View File

@@ -2,8 +2,7 @@ from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
@@ -60,6 +59,7 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""

View File

@@ -6,10 +6,7 @@ import jax
from jax import jit, vmap
from configs import Configer
from algorithms.neat import initialize_genomes
from algorithms.neat.population import create_next_generation, speciate, update_species
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
from algorithms import neat
class Pipeline:
@@ -32,7 +29,7 @@ class Pipeline:
self.generation = 0
self.best_genome = None
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
self.species_info = np.full((self.S, 3), np.nan)
self.species_info[0, :] = 0, -np.inf, 0
self.idx2species = np.zeros(self.P, dtype=np.float32)
@@ -47,9 +44,9 @@ class Pipeline:
self.evaluate_time = 0
self.pop_unflatten_connections = jit(vmap(unflatten_connections))
self.pop_topological_sort = jit(vmap(topological_sort))
self.forward = create_forward_function(config)
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
self.pop_topological_sort = jit(vmap(neat.topological_sort))
self.forward = neat.create_forward_function(config)
def ask(self):
"""
@@ -84,13 +81,13 @@ class Pipeline:
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
elite_mask, self.generation, self.jit_config)
self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate(
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
self.jit_config)

View File

@@ -4,8 +4,7 @@ contains operations on the population: creating the next generation and populati
import jax
from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first, rank_elements
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
@jit

View File

@@ -4,8 +4,7 @@ import configparser
import numpy as np
from algorithms.neat.genome.activations import act_name2func
from algorithms.neat.genome.aggregations import agg_name2func
from algorithms.neat import act_name2func, agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [

View File

@@ -1,8 +1,8 @@
from functools import partial
import jax
from jax import numpy as jnp, jit
@partial(jit, static_argnames=['reverse'])
def rank_element(array, reverse=False):
"""