diff --git a/algorithms/neat/__init__.py b/algorithms/neat/__init__.py index d94a11a..71ad225 100644 --- a/algorithms/neat/__init__.py +++ b/algorithms/neat/__init__.py @@ -1,5 +1,8 @@ """ contains operations on a single genome. e.g. forward, mutate, crossover, etc. """ -from .genome import create_forward_function, topological_sort, unflatten_connections +from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes from .population import update_species, create_next_generation, speciate + +from .genome.activations import act_name2func +from .genome.aggregations import agg_name2func diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py index 97f89a0..b98155f 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -2,5 +2,6 @@ from .mutate import mutate from .distance import distance from .crossover import crossover from .graph import topological_sort, check_cycles -from .utils import unflatten_connections +from .utils import unflatten_connections, I_INT, fetch_first, rank_elements from .forward import create_forward_function +from .genome import initialize_genomes diff --git a/algorithms/neat/genome/activations.py b/algorithms/neat/genome/activations.py index 48bff71..3cd828e 100644 --- a/algorithms/neat/genome/activations.py +++ b/algorithms/neat/genome/activations.py @@ -1,105 +1,85 @@ -import jax import jax.numpy as jnp -from jax import jit -@jit def sigmoid_act(z): z = jnp.clip(z * 5, -60, 60) return 1 / (1 + jnp.exp(-z)) -@jit def tanh_act(z): z = jnp.clip(z * 2.5, -60, 60) return jnp.tanh(z) -@jit def sin_act(z): z = jnp.clip(z * 5, -60, 60) return jnp.sin(z) -@jit def gauss_act(z): z = jnp.clip(z * 5, -3.4, 3.4) return jnp.exp(-z ** 2) -@jit def relu_act(z): return jnp.maximum(z, 0) -@jit def elu_act(z): return jnp.where(z > 0, z, jnp.exp(z) - 1) -@jit def lelu_act(z): leaky = 0.005 return jnp.where(z > 0, z, leaky * z) -@jit def selu_act(z): lam = 1.0507009873554804934193349852946 alpha = 1.6732632423543772848170429916717 return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1)) -@jit def softplus_act(z): z = jnp.clip(z * 5, -60, 60) return 0.2 * jnp.log(1 + jnp.exp(z)) -@jit def identity_act(z): return z -@jit def clamped_act(z): return jnp.clip(z, -1, 1) -@jit def inv_act(z): z = jnp.maximum(z, 1e-7) return 1 / z -@jit def log_act(z): z = jnp.maximum(z, 1e-7) return jnp.log(z) -@jit def exp_act(z): z = jnp.clip(z, -60, 60) return jnp.exp(z) -@jit def abs_act(z): return jnp.abs(z) -@jit def hat_act(z): return jnp.maximum(0, 1 - jnp.abs(z)) -@jit def square_act(z): return z ** 2 -@jit def cube_act(z): return z ** 3 diff --git a/algorithms/neat/genome/aggregations.py b/algorithms/neat/genome/aggregations.py index a9eb8e6..81c61c9 100644 --- a/algorithms/neat/genome/aggregations.py +++ b/algorithms/neat/genome/aggregations.py @@ -1,7 +1,6 @@ import jax.numpy as jnp - def sum_agg(z): z = jnp.where(jnp.isnan(z), 0, z) return jnp.sum(z, axis=0) diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index 2a02d9b..c27fa9b 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -6,8 +6,7 @@ See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGen from typing import Tuple import jax -from jax import jit, Array -from jax import numpy as jnp +from jax import jit, Array, numpy as jnp @jit diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 69e421e..4eacae4 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -5,8 +5,7 @@ See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py """ from typing import Dict -from jax import jit, vmap, Array -from jax import numpy as jnp +from jax import jit, vmap, Array, numpy as jnp from .utils import EMPTY_NODE, EMPTY_CON diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index 9a4b161..efa2b06 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -1,6 +1,5 @@ import jax -from jax import Array, numpy as jnp -from jax import jit, vmap +from jax import Array, numpy as jnp, jit, vmap from .utils import I_INT diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py index 746e30c..b37a12b 100644 --- a/algorithms/neat/genome/graph.py +++ b/algorithms/neat/genome/graph.py @@ -4,10 +4,8 @@ Only used in feed-forward networks. """ import jax -from jax import jit, Array -from jax import numpy as jnp +from jax import jit, Array, numpy as jnp -# from .configs import fetch_first, I_INT from algorithms.neat.genome.utils import fetch_first, I_INT diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index 331dd11..3fe8a70 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -4,11 +4,9 @@ The calculation method is the same as the mutation operation in NEAT-python. See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate """ from typing import Tuple, Dict -from functools import partial import jax -from jax import numpy as jnp -from jax import jit, Array +from jax import numpy as jnp, jit, Array from .utils import fetch_random, fetch_first, I_INT, unflatten_connections from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py index 9ae2e95..673d662 100644 --- a/algorithms/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -2,8 +2,7 @@ from functools import partial import numpy as np import jax -from jax import numpy as jnp, Array -from jax import jit, vmap +from jax import numpy as jnp, Array, jit, vmap I_INT = np.iinfo(jnp.int32).max # infinite int EMPTY_NODE = np.full((1, 5), jnp.nan) @@ -60,6 +59,7 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array: mask = jnp.where(true_cnt == 0, False, cumsum >= target) return fetch_first(mask, default) + @partial(jit, static_argnames=['reverse']) def rank_elements(array, reverse=False): """ @@ -68,4 +68,4 @@ def rank_elements(array, reverse=False): """ if not reverse: array = -array - return jnp.argsort(jnp.argsort(array)) \ No newline at end of file + return jnp.argsort(jnp.argsort(array)) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index ac2ea3f..8b1675f 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -6,10 +6,7 @@ import jax from jax import jit, vmap from configs import Configer -from algorithms.neat import initialize_genomes - -from algorithms.neat.population import create_next_generation, speciate, update_species -from algorithms.neat import unflatten_connections, topological_sort, create_forward_function +from algorithms import neat class Pipeline: @@ -32,7 +29,7 @@ class Pipeline: self.generation = 0 self.best_genome = None - self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config) + self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config) self.species_info = np.full((self.S, 3), np.nan) self.species_info[0, :] = 0, -np.inf, 0 self.idx2species = np.zeros(self.P, dtype=np.float32) @@ -47,9 +44,9 @@ class Pipeline: self.evaluate_time = 0 - self.pop_unflatten_connections = jit(vmap(unflatten_connections)) - self.pop_topological_sort = jit(vmap(topological_sort)) - self.forward = create_forward_function(config) + self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections)) + self.pop_topological_sort = jit(vmap(neat.topological_sort)) + self.forward = neat.create_forward_function(config) def ask(self): """ @@ -84,13 +81,13 @@ class Pipeline: k1, k2, self.randkey = jax.random.split(self.randkey, 3) self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \ - update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes, - self.center_cons, self.generation, self.jit_config) + neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes, + self.center_cons, self.generation, self.jit_config) - self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser, - elite_mask, self.generation, self.jit_config) + self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser, + elite_mask, self.generation, self.jit_config) - self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate( + self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate( self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation, self.jit_config) diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py index f8a1f96..aa49b0f 100644 --- a/algorithms/neat/population.py +++ b/algorithms/neat/population.py @@ -4,8 +4,7 @@ contains operations on the population: creating the next generation and populati import jax from jax import jit, vmap, Array, numpy as jnp -from .genome import distance, mutate, crossover -from .genome.utils import I_INT, fetch_first, rank_elements +from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements @jit diff --git a/configs/configer.py b/configs/configer.py index ca24e61..a899bcf 100644 --- a/configs/configer.py +++ b/configs/configer.py @@ -4,8 +4,7 @@ import configparser import numpy as np -from algorithms.neat.genome.activations import act_name2func -from algorithms.neat.genome.aggregations import agg_name2func +from algorithms.neat import act_name2func, agg_name2func # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. jit_config_keys = [ diff --git a/examples/jax_playground.py b/examples/jax_playground.py index f3dd308..fc22005 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,8 +1,8 @@ from functools import partial -import jax from jax import numpy as jnp, jit + @partial(jit, static_argnames=['reverse']) def rank_element(array, reverse=False): """ @@ -14,5 +14,5 @@ def rank_element(array, reverse=False): return jnp.argsort(jnp.argsort(array)) -a = jnp.array([1 ,5, 3, 5, 2, 1, 0]) -print(rank_element(a, reverse=True)) \ No newline at end of file +a = jnp.array([1, 5, 3, 5, 2, 1, 0]) +print(rank_element(a, reverse=True))