optimize import

This commit is contained in:
wls2002
2023-06-29 09:41:49 +08:00
parent d28cef1a87
commit 01b7731231
14 changed files with 29 additions and 58 deletions

View File

@@ -1,5 +1,8 @@
""" """
contains operations on a single genome. e.g. forward, mutate, crossover, etc. contains operations on a single genome. e.g. forward, mutate, crossover, etc.
""" """
from .genome import create_forward_function, topological_sort, unflatten_connections from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate from .population import update_species, create_next_generation, speciate
from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func

View File

@@ -2,5 +2,6 @@ from .mutate import mutate
from .distance import distance from .distance import distance
from .crossover import crossover from .crossover import crossover
from .graph import topological_sort, check_cycles from .graph import topological_sort, check_cycles
from .utils import unflatten_connections from .utils import unflatten_connections, I_INT, fetch_first, rank_elements
from .forward import create_forward_function from .forward import create_forward_function
from .genome import initialize_genomes

View File

@@ -1,105 +1,85 @@
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import jit
@jit
def sigmoid_act(z): def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60) z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z)) return 1 / (1 + jnp.exp(-z))
@jit
def tanh_act(z): def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60) z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z) return jnp.tanh(z)
@jit
def sin_act(z): def sin_act(z):
z = jnp.clip(z * 5, -60, 60) z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z) return jnp.sin(z)
@jit
def gauss_act(z): def gauss_act(z):
z = jnp.clip(z * 5, -3.4, 3.4) z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2) return jnp.exp(-z ** 2)
@jit
def relu_act(z): def relu_act(z):
return jnp.maximum(z, 0) return jnp.maximum(z, 0)
@jit
def elu_act(z): def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1) return jnp.where(z > 0, z, jnp.exp(z) - 1)
@jit
def lelu_act(z): def lelu_act(z):
leaky = 0.005 leaky = 0.005
return jnp.where(z > 0, z, leaky * z) return jnp.where(z > 0, z, leaky * z)
@jit
def selu_act(z): def selu_act(z):
lam = 1.0507009873554804934193349852946 lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717 alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1)) return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@jit
def softplus_act(z): def softplus_act(z):
z = jnp.clip(z * 5, -60, 60) z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z)) return 0.2 * jnp.log(1 + jnp.exp(z))
@jit
def identity_act(z): def identity_act(z):
return z return z
@jit
def clamped_act(z): def clamped_act(z):
return jnp.clip(z, -1, 1) return jnp.clip(z, -1, 1)
@jit
def inv_act(z): def inv_act(z):
z = jnp.maximum(z, 1e-7) z = jnp.maximum(z, 1e-7)
return 1 / z return 1 / z
@jit
def log_act(z): def log_act(z):
z = jnp.maximum(z, 1e-7) z = jnp.maximum(z, 1e-7)
return jnp.log(z) return jnp.log(z)
@jit
def exp_act(z): def exp_act(z):
z = jnp.clip(z, -60, 60) z = jnp.clip(z, -60, 60)
return jnp.exp(z) return jnp.exp(z)
@jit
def abs_act(z): def abs_act(z):
return jnp.abs(z) return jnp.abs(z)
@jit
def hat_act(z): def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z)) return jnp.maximum(0, 1 - jnp.abs(z))
@jit
def square_act(z): def square_act(z):
return z ** 2 return z ** 2
@jit
def cube_act(z): def cube_act(z):
return z ** 3 return z ** 3

View File

@@ -1,7 +1,6 @@
import jax.numpy as jnp import jax.numpy as jnp
def sum_agg(z): def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z) z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0) return jnp.sum(z, axis=0)

View File

@@ -6,8 +6,7 @@ See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGen
from typing import Tuple from typing import Tuple
import jax import jax
from jax import jit, Array from jax import jit, Array, numpy as jnp
from jax import numpy as jnp
@jit @jit

View File

@@ -5,8 +5,7 @@ See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
""" """
from typing import Dict from typing import Dict
from jax import jit, vmap, Array from jax import jit, vmap, Array, numpy as jnp
from jax import numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON from .utils import EMPTY_NODE, EMPTY_CON

View File

@@ -1,6 +1,5 @@
import jax import jax
from jax import Array, numpy as jnp from jax import Array, numpy as jnp, jit, vmap
from jax import jit, vmap
from .utils import I_INT from .utils import I_INT

View File

@@ -4,10 +4,8 @@ Only used in feed-forward networks.
""" """
import jax import jax
from jax import jit, Array from jax import jit, Array, numpy as jnp
from jax import numpy as jnp
# from .configs import fetch_first, I_INT
from algorithms.neat.genome.utils import fetch_first, I_INT from algorithms.neat.genome.utils import fetch_first, I_INT

View File

@@ -4,11 +4,9 @@ The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
""" """
from typing import Tuple, Dict from typing import Tuple, Dict
from functools import partial
import jax import jax
from jax import numpy as jnp from jax import numpy as jnp, jit, Array
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection

View File

@@ -2,8 +2,7 @@ from functools import partial
import numpy as np import numpy as np
import jax import jax
from jax import numpy as jnp, Array from jax import numpy as jnp, Array, jit, vmap
from jax import jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan) EMPTY_NODE = np.full((1, 5), jnp.nan)
@@ -60,6 +59,7 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target) mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default) return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse']) @partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False): def rank_elements(array, reverse=False):
""" """

View File

@@ -6,10 +6,7 @@ import jax
from jax import jit, vmap from jax import jit, vmap
from configs import Configer from configs import Configer
from algorithms.neat import initialize_genomes from algorithms import neat
from algorithms.neat.population import create_next_generation, speciate, update_species
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
class Pipeline: class Pipeline:
@@ -32,7 +29,7 @@ class Pipeline:
self.generation = 0 self.generation = 0
self.best_genome = None self.best_genome = None
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config) self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
self.species_info = np.full((self.S, 3), np.nan) self.species_info = np.full((self.S, 3), np.nan)
self.species_info[0, :] = 0, -np.inf, 0 self.species_info[0, :] = 0, -np.inf, 0
self.idx2species = np.zeros(self.P, dtype=np.float32) self.idx2species = np.zeros(self.P, dtype=np.float32)
@@ -47,9 +44,9 @@ class Pipeline:
self.evaluate_time = 0 self.evaluate_time = 0
self.pop_unflatten_connections = jit(vmap(unflatten_connections)) self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
self.pop_topological_sort = jit(vmap(topological_sort)) self.pop_topological_sort = jit(vmap(neat.topological_sort))
self.forward = create_forward_function(config) self.forward = neat.create_forward_function(config)
def ask(self): def ask(self):
""" """
@@ -84,13 +81,13 @@ class Pipeline:
k1, k2, self.randkey = jax.random.split(self.randkey, 3) k1, k2, self.randkey = jax.random.split(self.randkey, 3)
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \ self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes, neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config) self.center_cons, self.generation, self.jit_config)
self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser, self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
elite_mask, self.generation, self.jit_config) elite_mask, self.generation, self.jit_config)
self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate( self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation, self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
self.jit_config) self.jit_config)

View File

@@ -4,8 +4,7 @@ contains operations on the population: creating the next generation and populati
import jax import jax
from jax import jit, vmap, Array, numpy as jnp from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
from .genome.utils import I_INT, fetch_first, rank_elements
@jit @jit

View File

@@ -4,8 +4,7 @@ import configparser
import numpy as np import numpy as np
from algorithms.neat.genome.activations import act_name2func from algorithms.neat import act_name2func, agg_name2func
from algorithms.neat.genome.aggregations import agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [ jit_config_keys = [

View File

@@ -1,8 +1,8 @@
from functools import partial from functools import partial
import jax
from jax import numpy as jnp, jit from jax import numpy as jnp, jit
@partial(jit, static_argnames=['reverse']) @partial(jit, static_argnames=['reverse'])
def rank_element(array, reverse=False): def rank_element(array, reverse=False):
""" """
@@ -14,5 +14,5 @@ def rank_element(array, reverse=False):
return jnp.argsort(jnp.argsort(array)) return jnp.argsort(jnp.argsort(array))
a = jnp.array([1 ,5, 3, 5, 2, 1, 0]) a = jnp.array([1, 5, 3, 5, 2, 1, 0])
print(rank_element(a, reverse=True)) print(rank_element(a, reverse=True))