From f6dcb97df8bce0b5b8df146a91405a474c76ef61 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 1 Jul 2023 13:36:19 +0800 Subject: [PATCH] modify method cal_spawn_numbers spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate'] --- algorithms/neat/__init__.py | 2 +- algorithms/neat/genome/forward.py | 2 +- algorithms/neat/population.py | 30 +++++++++++++++++++++++- configs/configer.py | 3 ++- configs/default_config.ini | 5 ++-- examples/xor.py | 4 +++- pipeline.py | 39 ++++++++++++++++++++----------- 7 files changed, 64 insertions(+), 21 deletions(-) diff --git a/algorithms/neat/__init__.py b/algorithms/neat/__init__.py index 445a48a..f8d364b 100644 --- a/algorithms/neat/__init__.py +++ b/algorithms/neat/__init__.py @@ -2,7 +2,7 @@ contains operations on a single genome. e.g. forward, mutate, crossover, etc. """ from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes -from .population import update_species, create_next_generation, speciate +from .population import update_species, create_next_generation, speciate, tell from .genome.activations import act_name2func from .genome.aggregations import agg_name2func diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index efa2b06..a0f26b7 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -100,4 +100,4 @@ def create_forward_function(config): elif config['forward_way'] == 'common': return jit(common_forward) - return forward + return jit(forward) diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py index a521a9e..44740e7 100644 --- a/algorithms/neat/population.py +++ b/algorithms/neat/population.py @@ -11,6 +11,28 @@ from jax import jit, vmap, Array, numpy as jnp from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements +@jit +def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, + jit_config): + + generation += 1 + + k1, k2, randkey = jax.random.split(randkey, 3) + + species_info, center_nodes, center_cons, winner, loser, elite_mask = \ + update_species(k1, fitness, species_info, idx2species, center_nodes, + center_cons, generation, jit_config) + + pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser, + elite_mask, generation, jit_config) + + idx2species, center_nodes, center_cons, species_info = speciate( + pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, + jit_config) + + return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation + + @jit def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config): """ @@ -110,7 +132,13 @@ def cal_spawn_numbers(species_info, jit_config): spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 - spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member + target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member + + # Avoid too much variation of numbers in a species + previous_size = species_info[:, 3].astype(jnp.int32) + spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate'] + + spawn_number = spawn_number.astype(jnp.int32) # must control the sum of spawn_number to be equal to pop_size error = jit_config['pop_size'] - jnp.sum(spawn_number) diff --git a/configs/configer.py b/configs/configer.py index a899bcf..d226eb6 100644 --- a/configs/configer.py +++ b/configs/configer.py @@ -44,7 +44,8 @@ jit_config_keys = [ "pop_size", "genome_elitism", "survival_threshold", - "species_elitism" + "species_elitism", + "spawn_number_move_rate" ] diff --git a/configs/default_config.ini b/configs/default_config.ini index b85cba9..0860532 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -2,7 +2,7 @@ num_inputs = 2 num_outputs = 1 init_maximum_nodes = 50 -init_maximum_connections = 200 +init_maximum_connections = 50 init_maximum_species = 10 expand_coe = 1.5 pre_expand_threshold = 0.75 @@ -13,7 +13,7 @@ batch_size = 4 fitness_threshold = 100000 generation_limit = 1000 fitness_criterion = "max" -pop_size = 150 +pop_size = 2000 [genome] compatibility_disjoint = 1.0 @@ -31,6 +31,7 @@ max_stagnation = 15 genome_elitism = 2 survival_threshold = 0.2 min_species_size = 1 +spawn_number_move_rate = 0.5 [gene-bias] bias_init_mean = 0.0 diff --git a/examples/xor.py b/examples/xor.py index fc5369e..228978a 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,3 +1,4 @@ +import jax import numpy as np from configs import Configer @@ -14,8 +15,9 @@ def evaluate(forward_func): :return: """ outs = forward_func(xor_inputs) + outs = jax.device_get(outs) fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return np.array(fitnesses) # returns a list + return fitnesses def main(): diff --git a/pipeline.py b/pipeline.py index 8e9bf72..bf749db 100644 --- a/pipeline.py +++ b/pipeline.py @@ -48,6 +48,26 @@ class Pipeline: self.pop_topological_sort = jit(vmap(neat.topological_sort)) self.forward = neat.create_forward_function(config) + # fitness_lower = np.zeros(self.P, dtype=np.float32) + # randkey_lower = np.zeros(2, dtype=np.uint32) + # pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32) + # pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32) + # species_info_lower = np.zeros((self.S, 4), dtype=np.float32) + # idx2species_lower = np.zeros(self.P, dtype=np.float32) + # center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32) + # center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32) + # + # self.tell_func = jit(neat.tell).lower(fitness_lower, + # randkey_lower, + # pop_nodes_lower, + # pop_cons_lower, + # species_info_lower, + # idx2species_lower, + # center_nodes_lower, + # center_cons_lower, + # 0, + # self.jit_config).compile() + def ask(self): """ Creates a function that receives a genome and returns a forward function. @@ -75,21 +95,12 @@ class Pipeline: assert self.config['forward_way'] == 'common' return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons) - def tell(self, fitnesses): - self.generation += 1 + def tell(self, fitness): - k1, k2, self.randkey = jax.random.split(self.randkey, 3) - - self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \ - neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes, - self.center_cons, self.generation, self.jit_config) - - self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser, - elite_mask, self.generation, self.jit_config) - - self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate( - self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation, - self.jit_config) + self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \ + self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons, + self.species_info, self.idx2species, self.center_nodes, + self.center_cons, self.generation, self.jit_config) def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config['generation_limit']):