From 114ff2b0ccd7de644b930dac52697dbd00a930c5 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 27 Jun 2023 18:47:47 +0800 Subject: [PATCH] modifying --- {neat/genome/debug => algorithms}/__init__.py | 0 algorithms/neat/__init__.py | 6 + {neat => algorithms/neat}/genome/__init__.py | 0 .../neat}/genome/activations.py | 0 .../neat}/genome/aggregations.py | 0 {neat => algorithms/neat}/genome/crossover.py | 0 algorithms/neat/genome/debug/__init__.py | 0 .../neat}/genome/debug/tools.py | 0 {neat => algorithms/neat}/genome/distance.py | 0 {neat => algorithms/neat}/genome/forward.py | 5 +- {neat => algorithms/neat}/genome/genome.py | 0 {neat => algorithms/neat}/genome/graph.py | 4 +- {neat => algorithms/neat}/genome/mutate.py | 0 {neat => algorithms/neat}/genome/utils.py | 13 +- algorithms/neat/jit_species.py | 160 ++++++++++++++++++ {neat => algorithms/neat}/operations.py | 7 +- {neat => algorithms/neat}/species.py | 1 - configs/configer.py | 9 +- configs/default_config.ini | 9 +- examples/a.py | 55 ------ examples/evox_test.py | 26 +++ examples/jax_playground.py | 37 ++-- examples/jit_xor.py | 28 +++ examples/xor.py | 4 +- ...function_factory.py => function_factory.py | 19 +-- jit_pipeline.py | 159 +++++++++++++++++ neat/__init__.py | 3 - neat/pipeline.py => pipeline.py | 29 ++-- 28 files changed, 451 insertions(+), 123 deletions(-) rename {neat/genome/debug => algorithms}/__init__.py (100%) create mode 100644 algorithms/neat/__init__.py rename {neat => algorithms/neat}/genome/__init__.py (100%) rename {neat => algorithms/neat}/genome/activations.py (100%) rename {neat => algorithms/neat}/genome/aggregations.py (100%) rename {neat => algorithms/neat}/genome/crossover.py (100%) create mode 100644 algorithms/neat/genome/debug/__init__.py rename {neat => algorithms/neat}/genome/debug/tools.py (100%) rename {neat => algorithms/neat}/genome/distance.py (100%) rename {neat => algorithms/neat}/genome/forward.py (97%) rename {neat => algorithms/neat}/genome/genome.py (100%) rename {neat => algorithms/neat}/genome/graph.py (97%) rename {neat => algorithms/neat}/genome/mutate.py (100%) rename {neat => algorithms/neat}/genome/utils.py (85%) create mode 100644 algorithms/neat/jit_species.py rename {neat => algorithms/neat}/operations.py (98%) rename {neat => algorithms/neat}/species.py (99%) delete mode 100644 examples/a.py create mode 100644 examples/evox_test.py create mode 100644 examples/jit_xor.py rename neat/function_factory.py => function_factory.py (89%) create mode 100644 jit_pipeline.py delete mode 100644 neat/__init__.py rename neat/pipeline.py => pipeline.py (88%) diff --git a/neat/genome/debug/__init__.py b/algorithms/__init__.py similarity index 100% rename from neat/genome/debug/__init__.py rename to algorithms/__init__.py diff --git a/algorithms/neat/__init__.py b/algorithms/neat/__init__.py new file mode 100644 index 0000000..f1c4384 --- /dev/null +++ b/algorithms/neat/__init__.py @@ -0,0 +1,6 @@ +""" +contains operations on a single genome. e.g. forward, mutate, crossover, etc. +""" +from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single +from .operations import create_next_generation_then_speciate +from .species import SpeciesController diff --git a/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py similarity index 100% rename from neat/genome/__init__.py rename to algorithms/neat/genome/__init__.py diff --git a/neat/genome/activations.py b/algorithms/neat/genome/activations.py similarity index 100% rename from neat/genome/activations.py rename to algorithms/neat/genome/activations.py diff --git a/neat/genome/aggregations.py b/algorithms/neat/genome/aggregations.py similarity index 100% rename from neat/genome/aggregations.py rename to algorithms/neat/genome/aggregations.py diff --git a/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py similarity index 100% rename from neat/genome/crossover.py rename to algorithms/neat/genome/crossover.py diff --git a/algorithms/neat/genome/debug/__init__.py b/algorithms/neat/genome/debug/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/neat/genome/debug/tools.py b/algorithms/neat/genome/debug/tools.py similarity index 100% rename from neat/genome/debug/tools.py rename to algorithms/neat/genome/debug/tools.py diff --git a/neat/genome/distance.py b/algorithms/neat/genome/distance.py similarity index 100% rename from neat/genome/distance.py rename to algorithms/neat/genome/distance.py diff --git a/neat/genome/forward.py b/algorithms/neat/genome/forward.py similarity index 97% rename from neat/genome/forward.py rename to algorithms/neat/genome/forward.py index 9eeb7e5..40f7106 100644 --- a/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -5,8 +5,11 @@ from jax import jit, vmap from .utils import I_INT -# TODO: enabled information doesn't influence forward. That is wrong! def create_forward(config): + """ + meta method to create forward function + """ + def act(idx, z): """ calculate activation function for each node diff --git a/neat/genome/genome.py b/algorithms/neat/genome/genome.py similarity index 100% rename from neat/genome/genome.py rename to algorithms/neat/genome/genome.py diff --git a/neat/genome/graph.py b/algorithms/neat/genome/graph.py similarity index 97% rename from neat/genome/graph.py rename to algorithms/neat/genome/graph.py index 6741cba..746e30c 100644 --- a/neat/genome/graph.py +++ b/algorithms/neat/genome/graph.py @@ -4,11 +4,11 @@ Only used in feed-forward networks. """ import jax -from jax import jit, vmap, Array +from jax import jit, Array from jax import numpy as jnp # from .configs import fetch_first, I_INT -from neat.genome.utils import fetch_first, I_INT, unflatten_connections +from algorithms.neat.genome.utils import fetch_first, I_INT @jit diff --git a/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py similarity index 100% rename from neat/genome/mutate.py rename to algorithms/neat/genome/mutate.py diff --git a/neat/genome/utils.py b/algorithms/neat/genome/utils.py similarity index 85% rename from neat/genome/utils.py rename to algorithms/neat/genome/utils.py index 93fadca..4a984c3 100644 --- a/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np import jax from jax import numpy as jnp, Array @@ -30,6 +32,7 @@ def unflatten_connections(nodes: Array, cons: Array): return res + def key_to_indices(key, keys): return fetch_first(key == keys) @@ -56,4 +59,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array: target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) mask = jnp.where(true_cnt == 0, False, cumsum >= target) return fetch_first(mask, default) - +@partial(jit, static_argnames=['reverse']) +def rank_elements(array, reverse=False): + """ + rank the element in the array. + if reverse is True, the rank is from large to small. + """ + if reverse: + array = -array + return jnp.argsort(jnp.argsort(array)) \ No newline at end of file diff --git a/algorithms/neat/jit_species.py b/algorithms/neat/jit_species.py new file mode 100644 index 0000000..a4f4f44 --- /dev/null +++ b/algorithms/neat/jit_species.py @@ -0,0 +1,160 @@ +from functools import partial + +import jax +from jax import jit, numpy as jnp, vmap + +from .genome.utils import rank_elements + + +@jit +def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config): + """ + args: + randkey: random key + fitness: Array[(pop_size,), float], the fitness of each individual + species_keys: Array[(species_size, 3), float], the information of each species + [species_key, best_score, last_update] + idx2species: Array[(pop_size,), int], map the individual to its species + center_nodes: Array[(species_size, N, 4), float], the center nodes of each species + center_cons: Array[(species_size, C, 4), float], the center connections of each species + generation: int, current generation + jit_config: Dict, the configuration of jit functions + """ + + # update the fitness of each species + species_fitness = update_species_fitness(species_info, idx2species, fitness) + + # stagnation species + species_fitness, species_info, center_nodes, center_cons = \ + stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config) + + # sort species_info by their fitness. (push nan to the end) + sort_indices = jnp.argsort(species_fitness)[::-1] + species_info = species_info[sort_indices] + center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices] + + # decide the number of members of each species by their fitness + spawn_number = cal_spawn_numbers(species_info, jit_config) + + # crossover info + winner, loser, elite_mask = \ + create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config) + + jax.debug.print("{}, {}", fitness, winner) + jax.debug.print("{}", fitness[winner]) + + return species_info, center_nodes, center_cons, winner, loser, elite_mask + + +def update_species_fitness(species_info, idx2species, fitness): + """ + obtain the fitness of the species by the fitness of each individual. + use max criterion. + """ + + def aux_func(idx): + species_key = species_info[idx, 0] + s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf) + f = jnp.max(s_fitness) + return f + + return vmap(aux_func)(jnp.arange(species_info.shape[0])) + + +def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config): + """ + stagnation species. + those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. + elitism species never stagnation + """ + + def aux_func(idx): + s_fitness = species_fitness[idx] + species_key, best_score, last_update = species_info[idx] + # stagnation condition + return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation']) + + st = vmap(aux_func)(jnp.arange(species_info.shape[0])) + + # elite species will not be stagnation + species_rank = rank_elements(species_fitness) + st = jnp.where(species_rank < jit_config['species_elitism'], False, st) # elitism never stagnation + + # set stagnation species to nan + species_info = jnp.where(st[:, None], jnp.nan, species_info) + center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes) + center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons) + species_fitness = jnp.where(st, jnp.nan, species_fitness) + + return species_fitness, species_info, center_nodes, center_cons + + +def cal_spawn_numbers(species_info, jit_config): + """ + decide the number of members of each species by their fitness rank. + the species with higher fitness will have more members + Linear ranking selection + e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] + """ + + is_species_valid = ~jnp.isnan(species_info[:, 0]) + valid_species_num = jnp.sum(is_species_valid) + denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 + + rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1] + spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] + spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 + + spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member + + # must control the sum of spawn_number to be equal to pop_size + error = jit_config['pop_size'] - jnp.sum(spawn_number) + spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number + + return spawn_number + + +def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config): + + species_size = species_info.shape[0] + pop_size = fitness.shape[0] + s_idx = jnp.arange(species_size) + p_idx = jnp.arange(pop_size) + + def aux_func(key, idx): + members = idx2species == species_info[idx, 0] + members_num = jnp.sum(members) + + members_fitness = jnp.where(members, fitness, jnp.nan) + sorted_member_indices = jnp.argsort(members_fitness)[::-1] + + elite_size = jit_config['genome_elitism'] + survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32) + + select_pro = (p_idx < survive_size) / survive_size + fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) + + # elite + fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) + ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) + elite = jnp.where(p_idx < elite_size, True, False) + return fa, ma, elite + + fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) + + spawn_number_cum = jnp.cumsum(spawn_number) + + def aux_func(idx): + loc = jnp.argmax(idx < spawn_number_cum) + + # elite genomes are at the beginning of the species + idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) + return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] + + part1, part2, elite_mask = vmap(aux_func)(p_idx) + + is_part1_win = fitness[part1] >= fitness[part2] + winner = jnp.where(is_part1_win, part1, part2) + loser = jnp.where(is_part1_win, part2, part1) + + return winner, loser, elite_mask diff --git a/neat/operations.py b/algorithms/neat/operations.py similarity index 98% rename from neat/operations.py rename to algorithms/neat/operations.py index 7984e2a..bb0a758 100644 --- a/neat/operations.py +++ b/algorithms/neat/operations.py @@ -1,13 +1,8 @@ """ contains operations on the population: creating the next generation and population speciation. """ -from functools import partial - import jax -import jax.numpy as jnp -from jax import jit, vmap - -from jax import Array +from jax import jit, vmap, Array, numpy as jnp from .genome import distance, mutate, crossover from .genome.utils import I_INT, fetch_first diff --git a/neat/species.py b/algorithms/neat/species.py similarity index 99% rename from neat/species.py rename to algorithms/neat/species.py index ce7415a..32952a4 100644 --- a/neat/species.py +++ b/algorithms/neat/species.py @@ -8,7 +8,6 @@ See """ from typing import List, Tuple, Dict -from itertools import count import numpy as np from numpy.typing import NDArray diff --git a/configs/configer.py b/configs/configer.py index 9118a38..ca24e61 100644 --- a/configs/configer.py +++ b/configs/configer.py @@ -4,8 +4,8 @@ import configparser import numpy as np -from neat.genome.activations import act_name2func -from neat.genome.aggregations import agg_name2func +from algorithms.neat.genome.activations import act_name2func +from algorithms.neat.genome.aggregations import agg_name2func # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. jit_config_keys = [ @@ -41,6 +41,11 @@ jit_config_keys = [ "weight_mutate_rate", "weight_replace_rate", "enable_mutate_rate", + "max_stagnation", + "pop_size", + "genome_elitism", + "survival_threshold", + "species_elitism" ] diff --git a/configs/default_config.ini b/configs/default_config.ini index d4bbcda..29810c1 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -1,10 +1,11 @@ [basic] num_inputs = 2 num_outputs = 1 -init_maximum_nodes = 50 -init_maximum_connections = 50 +init_maximum_nodes = 200 +init_maximum_connections = 200 init_maximum_species = 10 -expand_coe = 2.0 +expand_coe = 1.5 +pre_expand_threshold = 0.75 forward_way = "pop" batch_size = 4 @@ -12,7 +13,7 @@ batch_size = 4 fitness_threshold = 100000 generation_limit = 100 fitness_criterion = "max" -pop_size = 15000 +pop_size = 150 [genome] compatibility_disjoint = 1.0 diff --git a/examples/a.py b/examples/a.py deleted file mode 100644 index c7138d9..0000000 --- a/examples/a.py +++ /dev/null @@ -1,55 +0,0 @@ -import numpy as np - -import jax.numpy as jnp -import jax - -a = {1:2, 2:3, 4:5} -print(a.values()) - -a = jnp.array([1, 0, 1, 0, np.nan]) -b = jnp.array([1, 1, 1, 1, 1]) -c = jnp.array([1, 1, 1, 1, 1]) - -full = jnp.array([ - [1, 1, 1], - [0, 1, 1], - [1, 1, 1], - [0, 1, 1], -]) - -print(jnp.column_stack([a[:, None], b[:, None], c[:, None]])) - -aux0 = full[:, 0, None] -aux1 = full[:, 1, None] - -print(aux0, aux0.shape) - -print(jnp.concatenate([aux0, aux1], axis=1)) - -f_a = jnp.array([False, False, True, True]) -f_b = jnp.array([True, False, False, False]) - -print(jnp.logical_and(f_a, f_b)) -print(f_a & f_b) - -print(f_a + jnp.nan * 0.0) -print(f_a + 1 * 0.0) - - -@jax.jit -def main(): - return func('happy') + func('sad') - - -def func(x): - if x == 'happy': - return 1 - else: - return 2 - -a = jnp.zeros((3, 3)) -print(a.dtype) - -c = None -b = 1 or c -print(b) \ No newline at end of file diff --git a/examples/evox_test.py b/examples/evox_test.py new file mode 100644 index 0000000..54150c1 --- /dev/null +++ b/examples/evox_test.py @@ -0,0 +1,26 @@ +import jax +from jax import numpy as jnp + +from evox import algorithms, problems, pipelines +from evox.monitors import StdSOMonitor + +monitor = StdSOMonitor() + +pso = algorithms.PSO( + lb=jnp.full(shape=(2,), fill_value=-32), + ub=jnp.full(shape=(2,), fill_value=32), + pop_size=100, +) + +ackley = problems.classic.Ackley() + +pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit) + +key = jax.random.PRNGKey(42) +state = pipeline.init(key) + +# run the pipeline for 100 steps +for i in range(100): + state = pipeline.step(state) + +print(monitor.get_min_fitness()) \ No newline at end of file diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 0969908..f3dd308 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,27 +1,18 @@ -import numpy as np -from jax import jit +from functools import partial -from configs import Configer -from neat.pipeline import Pipeline +import jax +from jax import numpy as jnp, jit -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - -def main(): - config = Configer.load_config("xor.ini") - print(config) - pipeline = Pipeline(config) - forward_func = pipeline.ask() - # inputs = np.tile(xor_inputs, (150, 1, 1)) - outputs = forward_func(xor_inputs) - print(outputs) +@partial(jit, static_argnames=['reverse']) +def rank_element(array, reverse=False): + """ + rank the element in the array. + if reverse is True, the rank is from large to small. + """ + if reverse: + array = -array + return jnp.argsort(jnp.argsort(array)) - -@jit -def f(x, jit_config): - return x + jit_config["bias_mutate_rate"] - - -if __name__ == '__main__': - main() +a = jnp.array([1 ,5, 3, 5, 2, 1, 0]) +print(rank_element(a, reverse=True)) \ No newline at end of file diff --git a/examples/jit_xor.py b/examples/jit_xor.py new file mode 100644 index 0000000..bd92570 --- /dev/null +++ b/examples/jit_xor.py @@ -0,0 +1,28 @@ +import numpy as np + +from configs import Configer +from jit_pipeline import Pipeline + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return np.array(fitnesses) # returns a list + + +def main(): + config = Configer.load_config("xor.ini") + pipeline = Pipeline(config, seed=6) + nodes, cons = pipeline.auto_run(evaluate) + print(nodes, cons) + + +if __name__ == '__main__': + main() diff --git a/examples/xor.py b/examples/xor.py index b509c76..f91583a 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,7 +1,7 @@ import numpy as np from configs import Configer -from neat.pipeline import Pipeline +from pipeline import Pipeline xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -21,6 +21,8 @@ def main(): config = Configer.load_config("xor.ini") pipeline = Pipeline(config, seed=6) nodes, cons = pipeline.auto_run(evaluate) + print(nodes, cons) + if __name__ == '__main__': main() diff --git a/neat/function_factory.py b/function_factory.py similarity index 89% rename from neat/function_factory.py rename to function_factory.py index 953cd98..c05cd2d 100644 --- a/neat/function_factory.py +++ b/function_factory.py @@ -1,8 +1,9 @@ import numpy as np from jax import jit, vmap -from .genome import create_forward, topological_sort, unflatten_connections -from .operations import create_next_generation_then_speciate +from algorithms.neat import create_forward, topological_sort, \ + unflatten_connections, create_next_generation_then_speciate + def hash_symbols(symbols): return symbols['P'], symbols['N'], symbols['C'], symbols['S'] @@ -32,7 +33,6 @@ class FunctionFactory: # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0)) - self.function_info = { "pop_unflatten_connections": { 'func': vmap(unflatten_connections), @@ -54,7 +54,7 @@ class FunctionFactory: 'func': batch_forward, 'lowers': [ {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32}, - {'shape': ('N', ), 'type': np.int32}, + {'shape': ('N',), 'type': np.int32}, {'shape': ('N', 5), 'type': np.float32}, {'shape': (2, 'N', 'N'), 'type': np.float32} ] @@ -83,23 +83,22 @@ class FunctionFactory: 'create_next_generation_then_speciate': { 'func': create_next_generation_then_speciate, 'lowers': [ - {'shape': (2, ), 'type': np.uint32}, # rand_key + {'shape': (2,), 'type': np.uint32}, # rand_key {'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes {'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons - {'shape': ('P', ), 'type': np.int32}, # winner - {'shape': ('P', ), 'type': np.int32}, # loser - {'shape': ('P', ), 'type': bool}, # elite_mask + {'shape': ('P',), 'type': np.int32}, # winner + {'shape': ('P',), 'type': np.int32}, # loser + {'shape': ('P',), 'type': bool}, # elite_mask {'shape': ('P',), 'type': np.int32}, # new_node_keys {'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes {'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons - {'shape': ('S', ), 'type': np.int32}, # species_keys + {'shape': ('S',), 'type': np.int32}, # species_keys {'shape': (), 'type': np.int32}, # new_species_key_start "jit_config" ] } } - def get(self, name, symbols): if (name, hash_symbols(symbols)) not in self.func_dict: self.compile(name, symbols) diff --git a/jit_pipeline.py b/jit_pipeline.py new file mode 100644 index 0000000..8bb03b0 --- /dev/null +++ b/jit_pipeline.py @@ -0,0 +1,159 @@ +import time +from typing import Union, Callable + +import numpy as np +import jax + +from configs import Configer +from function_factory import FunctionFactory +from algorithms.neat import initialize_genomes, expand, expand_single + +from algorithms.neat.jit_species import update_species +from algorithms.neat.operations import create_next_generation_then_speciate + + +class Pipeline: + """ + Neat algorithm pipeline. + """ + + def __init__(self, config, function_factory=None, seed=42): + self.randkey = jax.random.PRNGKey(seed) + np.random.seed(seed) + + self.config = config # global config + self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions + self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config) + + self.symbols = { + 'P': self.config['pop_size'], + 'N': self.config['init_maximum_nodes'], + 'C': self.config['init_maximum_connections'], + 'S': self.config['init_maximum_species'], + } + + self.generation = 0 + self.best_genome = None + + self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config) + self.species_info = np.full((self.symbols['S'], 3), np.nan) + self.species_info[0, :] = 0, -np.inf, 0 + self.idx2species = np.zeros(self.symbols['P'], dtype=np.int32) + self.center_nodes = np.full((self.symbols['S'], self.symbols['N'], 5), np.nan) + self.center_cons = np.full((self.symbols['S'], self.symbols['C'], 4), np.nan) + self.center_nodes[0, :, :] = self.pop_nodes[0, :, :] + self.center_cons[0, :, :] = self.pop_cons[0, :, :] + + self.best_fitness = float('-inf') + self.best_genome = None + self.generation_timestamp = time.time() + + self.evaluate_time = 0 + print(self.config) + + def ask(self): + """ + Creates a function that receives a genome and returns a forward function. + There are 3 types of config['forward_way']: {'single', 'pop', 'common'} + + single: + Create pop_size number of forward functions. + Each function receive (batch_size, input_size) and returns (batch_size, output_size) + e.g. RL task + + pop: + Create a single forward function, which use only once calculation for the population. + The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size) + + common: + Special case of pop. The population has the same inputs. + The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size) + e.g. numerical regression; Hyper-NEAT + + """ + u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons) + pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons) + + if self.config['forward_way'] == 'single': + forward_funcs = [] + for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons): + func = lambda x: self.get_func('forward')(x, seq, nodes, cons) + forward_funcs.append(func) + return forward_funcs + + elif self.config['forward_way'] == 'pop': + func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) + return func + + elif self.config['forward_way'] == 'common': + func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) + return func + + else: + raise NotImplementedError + + def tell(self, fitnesses): + self.generation += 1 + + species_info, center_nodes, center_cons, winner, loser, elite_mask = \ + update_species(self.randkey, fitnesses, self.species_info, self.idx2species, self.center_nodes, + self.center_cons, self.generation, self.jit_config) + + # node keys to be used in the mutation process + new_node_keys = np.arange(self.generation * self.config['pop_size'], + self.generation * self.config['pop_size'] + self.config['pop_size']) + + # create the next generation and then speciate the population + self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ + create_next_generation_then_speciate(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes, + center_cons, species_keys, species_key_start, self.jit_config) + + # carry data to cpu + self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ + jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys]) + + # update randkey + self.randkey = jax.random.split(self.randkey)[0] + + def get_func(self, name): + return self.function_factory.get(name, self.symbols) + + def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): + for _ in range(self.config['generation_limit']): + forward_func = self.ask() + + tic = time.time() + fitnesses = fitness_func(forward_func) + self.evaluate_time += time.time() - tic + + assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" + + if analysis is not None: + if analysis == "default": + self.default_analysis(fitnesses) + else: + assert callable(analysis), f"What the fuck you passed in? A {analysis}?" + analysis(fitnesses) + + if max(fitnesses) >= self.config['fitness_threshold']: + print("Fitness limit reached!") + return self.best_genome + + self.tell(fitnesses) + print("Generation limit reached!") + return self.best_genome + + def default_analysis(self, fitnesses): + max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) + + new_timestamp = time.time() + cost_time = new_timestamp - self.generation_timestamp + self.generation_timestamp = new_timestamp + + max_idx = np.argmax(fitnesses) + if fitnesses[max_idx] > self.best_fitness: + self.best_fitness = fitnesses[max_idx] + self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) + + print(f"Generation: {self.generation}", + f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}") diff --git a/neat/__init__.py b/neat/__init__.py deleted file mode 100644 index 8f8da22..0000000 --- a/neat/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -contains operations on a single genome. e.g. forward, mutate, crossover, etc. -""" \ No newline at end of file diff --git a/neat/pipeline.py b/pipeline.py similarity index 88% rename from neat/pipeline.py rename to pipeline.py index 45c92f4..2b6bf37 100644 --- a/neat/pipeline.py +++ b/pipeline.py @@ -5,9 +5,8 @@ import numpy as np import jax from configs import Configer -from .genome import initialize_genomes, expand, expand_single -from .function_factory import FunctionFactory -from .species import SpeciesController +from function_factory import FunctionFactory +from algorithms.neat import initialize_genomes, expand, expand_single, SpeciesController class Pipeline: @@ -119,25 +118,27 @@ class Pipeline: when the maximum node number >= N or the maximum connection number of >= C the population will expand """ - changed = False + # analysis nodes pop_node_keys = self.pop_nodes[:, :, 0] pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1) max_node_size = np.max(pop_node_sizes) - if max_node_size >= self.symbols['N']: - self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe']) - print(f"node expand to {self.symbols['N']}!") - changed = True + # analysis connections pop_con_keys = self.pop_cons[:, :, 0] pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) max_con_size = np.max(pop_node_sizes) - if max_con_size >= self.symbols['C']: - self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe']) - print(f"connection expand to {self.symbols['C']}!") - changed = True - if changed: + # expand if needed + if max_node_size >= self.symbols['N'] or max_con_size >= self.symbols['C']: + if max_node_size > self.symbols['N'] * self.config['pre_expand_threshold']: + self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe']) + print(f"pre node expand to {self.symbols['N']}!") + + if max_con_size > self.symbols['C'] * self.config['pre_expand_threshold']: + self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe']) + print(f"pre connection expand to {self.symbols['C']}!") + self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C']) # don't forget to expand representation genome in species for s in self.species_controller.species.values(): @@ -160,7 +161,7 @@ class Pipeline: if analysis == "default": self.default_analysis(fitnesses) else: - assert callable(analysis), f"What the fuck you passed in? A {analysis}?" + assert callable(analysis), f"Callable is needed hereπŸ˜…πŸ˜…πŸ˜… A {analysis}?" analysis(fitnesses) if max(fitnesses) >= self.config['fitness_threshold']: