optimize import
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
"""
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward_function, topological_sort, unflatten_connections
|
||||
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
||||
from .population import update_species, create_next_generation, speciate
|
||||
|
||||
from .genome.activations import act_name2func
|
||||
from .genome.aggregations import agg_name2func
|
||||
|
||||
@@ -2,5 +2,6 @@ from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .crossover import crossover
|
||||
from .graph import topological_sort, check_cycles
|
||||
from .utils import unflatten_connections
|
||||
from .utils import unflatten_connections, I_INT, fetch_first, rank_elements
|
||||
from .forward import create_forward_function
|
||||
from .genome import initialize_genomes
|
||||
|
||||
@@ -1,105 +1,85 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
|
||||
|
||||
@jit
|
||||
def sigmoid_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 1 / (1 + jnp.exp(-z))
|
||||
|
||||
|
||||
@jit
|
||||
def tanh_act(z):
|
||||
z = jnp.clip(z * 2.5, -60, 60)
|
||||
return jnp.tanh(z)
|
||||
|
||||
|
||||
@jit
|
||||
def sin_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return jnp.sin(z)
|
||||
|
||||
|
||||
@jit
|
||||
def gauss_act(z):
|
||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||
return jnp.exp(-z ** 2)
|
||||
|
||||
|
||||
@jit
|
||||
def relu_act(z):
|
||||
return jnp.maximum(z, 0)
|
||||
|
||||
|
||||
@jit
|
||||
def elu_act(z):
|
||||
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
||||
|
||||
|
||||
@jit
|
||||
def lelu_act(z):
|
||||
leaky = 0.005
|
||||
return jnp.where(z > 0, z, leaky * z)
|
||||
|
||||
|
||||
@jit
|
||||
def selu_act(z):
|
||||
lam = 1.0507009873554804934193349852946
|
||||
alpha = 1.6732632423543772848170429916717
|
||||
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
||||
|
||||
|
||||
@jit
|
||||
def softplus_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 0.2 * jnp.log(1 + jnp.exp(z))
|
||||
|
||||
|
||||
@jit
|
||||
def identity_act(z):
|
||||
return z
|
||||
|
||||
|
||||
@jit
|
||||
def clamped_act(z):
|
||||
return jnp.clip(z, -1, 1)
|
||||
|
||||
|
||||
@jit
|
||||
def inv_act(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return 1 / z
|
||||
|
||||
|
||||
@jit
|
||||
def log_act(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return jnp.log(z)
|
||||
|
||||
|
||||
@jit
|
||||
def exp_act(z):
|
||||
z = jnp.clip(z, -60, 60)
|
||||
return jnp.exp(z)
|
||||
|
||||
|
||||
@jit
|
||||
def abs_act(z):
|
||||
return jnp.abs(z)
|
||||
|
||||
|
||||
@jit
|
||||
def hat_act(z):
|
||||
return jnp.maximum(0, 1 - jnp.abs(z))
|
||||
|
||||
|
||||
@jit
|
||||
def square_act(z):
|
||||
return z ** 2
|
||||
|
||||
|
||||
@jit
|
||||
def cube_act(z):
|
||||
return z ** 3
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
|
||||
def sum_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
|
||||
@@ -6,8 +6,7 @@ See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGen
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -5,8 +5,7 @@ See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .utils import EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
from jax import jit, vmap
|
||||
from jax import Array, numpy as jnp, jit, vmap
|
||||
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
@@ -4,10 +4,8 @@ Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
import jax
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
# from .configs import fetch_first, I_INT
|
||||
from algorithms.neat.genome.utils import fetch_first, I_INT
|
||||
|
||||
|
||||
|
||||
@@ -4,11 +4,9 @@ The calculation method is the same as the mutation operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
|
||||
"""
|
||||
from typing import Tuple, Dict
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp, jit, Array
|
||||
|
||||
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||
|
||||
@@ -2,8 +2,7 @@ from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
from jax import jit, vmap
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = np.full((1, 5), jnp.nan)
|
||||
@@ -60,6 +59,7 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
@@ -68,4 +68,4 @@ def rank_elements(array, reverse=False):
|
||||
"""
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
|
||||
@@ -6,10 +6,7 @@ import jax
|
||||
from jax import jit, vmap
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import initialize_genomes
|
||||
|
||||
from algorithms.neat.population import create_next_generation, speciate, update_species
|
||||
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
|
||||
from algorithms import neat
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -32,7 +29,7 @@ class Pipeline:
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
|
||||
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
|
||||
self.species_info = np.full((self.S, 3), np.nan)
|
||||
self.species_info[0, :] = 0, -np.inf, 0
|
||||
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
||||
@@ -47,9 +44,9 @@ class Pipeline:
|
||||
|
||||
self.evaluate_time = 0
|
||||
|
||||
self.pop_unflatten_connections = jit(vmap(unflatten_connections))
|
||||
self.pop_topological_sort = jit(vmap(topological_sort))
|
||||
self.forward = create_forward_function(config)
|
||||
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
|
||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||
self.forward = neat.create_forward_function(config)
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
@@ -84,13 +81,13 @@ class Pipeline:
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
|
||||
update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
|
||||
self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
||||
elite_mask, self.generation, self.jit_config)
|
||||
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
||||
elite_mask, self.generation, self.jit_config)
|
||||
|
||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate(
|
||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
|
||||
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
|
||||
self.jit_config)
|
||||
|
||||
|
||||
@@ -4,8 +4,7 @@ contains operations on the population: creating the next generation and populati
|
||||
import jax
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first, rank_elements
|
||||
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
Reference in New Issue
Block a user