From c4d34e877b8480322bff3d934feabb23cc527480 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 2 Jul 2023 22:15:26 +0800 Subject: [PATCH] perfect! fix bug about jax auto recompile add task xor-3d --- algorithms/neat/__init__.py | 2 +- algorithms/neat/genome/genome.py | 6 +- algorithms/neat/population.py | 142 ++++++++++++++++++++++--------- configs/configer.py | 3 +- configs/default_config.ini | 13 ++- examples/debug.py | 2 - examples/xor.ini | 2 +- examples/xor.py | 7 +- examples/xor3d.ini | 47 ++++++++++ examples/xor3d.py | 31 +++++++ pipeline.py | 83 +++++++++--------- 11 files changed, 234 insertions(+), 104 deletions(-) create mode 100644 examples/xor3d.ini create mode 100644 examples/xor3d.py diff --git a/algorithms/neat/__init__.py b/algorithms/neat/__init__.py index f8d364b..37f1924 100644 --- a/algorithms/neat/__init__.py +++ b/algorithms/neat/__init__.py @@ -2,7 +2,7 @@ contains operations on a single genome. e.g. forward, mutate, crossover, etc. """ from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes -from .population import update_species, create_next_generation, speciate, tell +from .population import update_species, create_next_generation, speciate, tell, initialize from .genome.activations import act_name2func from .genome.aggregations import agg_name2func diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 2679289..67aac68 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]: assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \ f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!" - pop_nodes = np.full((config['pop_size'], N, 5), np.nan) - pop_cons = np.full((config['pop_size'], C, 4), np.nan) + pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32) + pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32) input_idx = config['input_idx'] output_idx = config['output_idx'] @@ -59,7 +59,7 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]: pop_cons[:, :p, 0] = grid_a pop_cons[:, :p, 1] = grid_b pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'], - size=(config['pop_size'], p)) + size=(config['pop_size'], p)) pop_cons[:, :p, 3] = 1 return pop_nodes, pop_cons diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py index 0119a28..61cab13 100644 --- a/algorithms/neat/population.py +++ b/algorithms/neat/population.py @@ -1,20 +1,88 @@ """ Contains operations on the population: creating the next generation and population speciation. -These im..... +The value tuple (P, N, C, S) is determined when the algorithm is initialized. + P: population size + N: maximum number of nodes in any genome + C: maximum number of connections in any genome + S: maximum number of species in NEAT + +These arrays are used in the algorithm: + fitness: Array[(P,), float], the fitness of each individual + randkey: Array[2, uint], the random key + pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg] + pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled] + species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count] + idx2species: Array[(P,), float], map the individual to its species keys + center_nodes: Array[(S, N, 5), float], the center nodes of each species + center_cons: Array[(S, C, 4), float], the center connections of each species + generation: int, the current generation + next_node_key: float, the next of the next node + next_species_key: float, the next of the next species + jit_config: Configer, the config used in jit-able functions """ # TODO: Complete python doc +import numpy as np import jax from jax import jit, vmap, Array, numpy as jnp -from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements +from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements +def initialize(config): + """ + initialize the states of NEAT. + """ + + P = config['pop_size'] + N = config['maximum_nodes'] + C = config['maximum_connections'] + S = config['maximum_species'] + + randkey = jax.random.PRNGKey(config['random_seed']) + np.random.seed(config['random_seed']) + pop_nodes, pop_cons = initialize_genomes(N, C, config) + species_info = np.full((S, 4), np.nan, dtype=np.float32) + species_info[0, :] = 0, -np.inf, 0, P + idx2species = np.zeros(P, dtype=np.float32) + center_nodes = np.full((S, N, 5), np.nan, dtype=np.float32) + center_cons = np.full((S, C, 4), np.nan, dtype=np.float32) + center_nodes[0, :, :] = pop_nodes[0, :, :] + center_cons[0, :, :] = pop_cons[0, :, :] + generation = np.asarray(0, dtype=np.int32) + next_node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32) + next_species_key = np.asarray(1, dtype=np.float32) + + return jax.device_put([ + randkey, + pop_nodes, + pop_cons, + species_info, + idx2species, + center_nodes, + center_cons, + generation, + next_node_key, + next_species_key, + ]) + @jit -def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, +def tell(fitness, + randkey, + pop_nodes, + pop_cons, + species_info, + idx2species, + center_nodes, + center_cons, + generation, + next_node_key, + next_species_key, jit_config): - + """ + Main update function in NEAT. + """ generation += 1 k1, k2, randkey = jax.random.split(randkey, 3) @@ -23,19 +91,15 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente update_species(k1, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config) + pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser, + elite_mask, next_node_key, jit_config) - pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser, - elite_mask, generation, jit_config) + idx2species, center_nodes, center_cons, species_info, next_species_key = speciate( + pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config) - idx2species, center_nodes, center_cons, species_info = speciate( - pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, - jit_config) + return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key - return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation - - -@jit def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config): """ args: @@ -199,11 +263,10 @@ def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitn return winner, loser, elite_mask -@jit -def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config): +def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config): # prepare random keys pop_size = pop_nodes.shape[0] - new_node_keys = jnp.arange(pop_size) + generation * pop_size + new_node_keys = jnp.arange(pop_size) + next_node_key k1, k2 = jax.random.split(rand_key, 2) crossover_rand_keys = jax.random.split(k1, pop_size) @@ -222,11 +285,15 @@ def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_m pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) - return pop_nodes, pop_cons + # update next node key + all_nodes_keys = pop_nodes[:, :, 0] + max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) + next_node_key = max_node_key + 1 + + return pop_nodes, pop_cons, next_node_key -@jit -def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config): +def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config): """ args: pop_nodes: (pop_size, N, 5) @@ -243,7 +310,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species # the distance between genomes to its center genomes - o2c_distances = jnp.full((pop_size, ), jnp.inf) + o2c_distances = jnp.full((pop_size,), jnp.inf) # step 1: find new centers def cond_func(carry): @@ -277,35 +344,35 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener # part 2: assign members to each species def cond_func(carry): - i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key + i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key # jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si) current_species_existed = ~jnp.isnan(si[i, 0]) not_all_assigned = jnp.any(jnp.isnan(i2s)) not_reach_species_upper_bounds = i < species_size - return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds) + return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) def body_func(carry): - i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons + i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons - _, i2s, scn, scc, si, o2c, ck = jax.lax.cond( + _, i2s, scn, scc, si, o2c, nsk = jax.lax.cond( jnp.isnan(si[i, 0]), # whether the current species is existing or not create_new_species, # if not existing, create a new specie update_exist_specie, # if existing, update the specie - (i, i2s, cn, cc, si, o2c, ck) + (i, i2s, cn, cc, si, o2c, nsk) ) - return i + 1, i2s, scn, scc, si, o2c, ck + return i + 1, i2s, scn, scc, si, o2c, nsk def create_new_species(carry): - i, i2s, cn, cc, si, o2c, ck = carry + i, i2s, cn, cc, si, o2c, nsk = carry # pick the first one who has not been assigned to any species idx = fetch_first(jnp.isnan(i2s)) # assign it to the new species # [key, best score, last update generation, members_count] - si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0])) - i2s = i2s.at[idx].set(ck) + si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0])) + i2s = i2s.at[idx].set(nsk) o2c = o2c.at[idx].set(0) # update center genomes @@ -315,14 +382,14 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) # when a new species is created, it needs to be updated, thus do not change i - return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key + return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key def update_exist_specie(carry): - i, i2s, cn, cc, si, o2c, ck = carry + i, i2s, cn, cc, si, o2c, nsk = carry i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) # turn to next species - return i + 1, i2s, cn, cc, si, o2c, ck + return i + 1, i2s, cn, cc, si, o2c, nsk def speciate_by_threshold(carry): i, i2s, cn, cc, si, o2c = carry @@ -344,15 +411,11 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener return i2s, o2c - species_keys = species_info[:, 0] - current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1 - - # update idx2specie - _, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop( + _, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop( cond_func, body_func, - (0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key) + (0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key) ) # if there are still some pop genomes not assigned to any species, add them to the last genome @@ -369,10 +432,9 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener species_member_counts = vmap(count_members)(jnp.arange(species_size)) species_info = species_info.at[:, 3].set(species_member_counts) - return idx2specie, center_nodes, center_cons, species_info + return idx2specie, center_nodes, center_cons, species_info, next_species_key -@jit def argmin_with_mask(arr: Array, mask: Array) -> Array: masked_arr = jnp.where(mask, arr, jnp.inf) min_idx = jnp.argmin(masked_arr) diff --git a/configs/configer.py b/configs/configer.py index d226eb6..79f60b3 100644 --- a/configs/configer.py +++ b/configs/configer.py @@ -4,7 +4,8 @@ import configparser import numpy as np -from algorithms.neat import act_name2func, agg_name2func +from algorithms.neat.genome.activations import act_name2func +from algorithms.neat.genome.aggregations import agg_name2func # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. jit_config_keys = [ diff --git a/configs/default_config.ini b/configs/default_config.ini index 3533f65..5378872 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -1,19 +1,18 @@ [basic] num_inputs = 2 num_outputs = 1 -init_maximum_nodes = 50 -init_maximum_connections = 50 -init_maximum_species = 10 -expand_coe = 1.5 -pre_expand_threshold = 0.75 +maximum_nodes = 50 +maximum_connections = 50 +maximum_species = 10 forward_way = "pop" batch_size = 4 +random_seed = 0 [population] -fitness_threshold = 100000 +fitness_threshold = 3.99999 generation_limit = 1000 fitness_criterion = "max" -pop_size = 50 +pop_size = 100000 [genome] compatibility_disjoint = 1.0 diff --git a/examples/debug.py b/examples/debug.py index aefc4ad..1a9f14a 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -34,8 +34,6 @@ def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topologica return evaluate(func) - - def equal(ar1, ar2): if ar1.shape != ar2.shape: return False diff --git a/examples/xor.ini b/examples/xor.ini index 233ace7..893fff7 100644 --- a/examples/xor.ini +++ b/examples/xor.ini @@ -2,4 +2,4 @@ forward_way = "common" [population] -fitness_threshold = 3.9999 \ No newline at end of file +fitness_threshold = 4 \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py index 228978a..1d88081 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -2,7 +2,6 @@ import jax import numpy as np from configs import Configer -from algorithms.neat import Genome from pipeline import Pipeline xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) @@ -22,10 +21,10 @@ def evaluate(forward_func): def main(): config = Configer.load_config("xor.ini") - pipeline = Pipeline(config, seed=6) + pipeline = Pipeline(config) nodes, cons = pipeline.auto_run(evaluate) - g = Genome(nodes, cons, config) - print(g) + # g = Genome(nodes, cons, config) + # print(g) if __name__ == '__main__': diff --git a/examples/xor3d.ini b/examples/xor3d.ini new file mode 100644 index 0000000..85b41af --- /dev/null +++ b/examples/xor3d.ini @@ -0,0 +1,47 @@ +[basic] +num_inputs = 3 +num_outputs = 1 +maximum_nodes = 50 +maximum_connections = 50 +maximum_species = 10 +forward_way = "common" +batch_size = 4 +random_seed = 42 + +[population] +fitness_threshold = 8 +generation_limit = 1000 +fitness_criterion = "max" +pop_size = 100000 + +[genome] +compatibility_disjoint = 1.0 +compatibility_weight = 0.5 +conn_add_prob = 0.4 +conn_add_trials = 1 +conn_delete_prob = 0 +node_add_prob = 0.2 +node_delete_prob = 0 + +[species] +compatibility_threshold = 3 +species_elitism = 1 +max_stagnation = 15 +genome_elitism = 2 +survival_threshold = 0.2 +min_species_size = 1 +spawn_number_move_rate = 0.5 + +[gene-bias] +bias_init_mean = 0.0 +bias_init_std = 1.0 +bias_mutate_power = 0.5 +bias_mutate_rate = 0.7 +bias_replace_rate = 0.1 + +[gene-weight] +weight_init_mean = 0.0 +weight_init_std = 1.0 +weight_mutate_power = 0.5 +weight_mutate_rate = 0.8 +weight_replace_rate = 0.1 diff --git a/examples/xor3d.py b/examples/xor3d.py new file mode 100644 index 0000000..1fa7c43 --- /dev/null +++ b/examples/xor3d.py @@ -0,0 +1,31 @@ +import jax +import numpy as np + +from configs import Configer +from pipeline import Pipeline + +xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses + + +def main(): + config = Configer.load_config("xor3d.ini") + pipeline = Pipeline(config) + nodes, cons = pipeline.auto_run(evaluate) + # g = Genome(nodes, cons, config) + # print(g) + + +if __name__ == '__main__': + main() diff --git a/pipeline.py b/pipeline.py index 71a83b5..24782d7 100644 --- a/pipeline.py +++ b/pipeline.py @@ -5,8 +5,8 @@ import numpy as np import jax from jax import jit, vmap -from configs import Configer from algorithms import neat +from configs.configer import Configer class Pipeline: @@ -14,58 +14,40 @@ class Pipeline: Neat algorithm pipeline. """ - def __init__(self, config, seed=42): - self.randkey = jax.random.PRNGKey(seed) - np.random.seed(seed) - + def __init__(self, config): self.config = config # global config - self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions + self.jit_config = Configer.create_jit_config(config) - self.P = config['pop_size'] - self.N = config['init_maximum_nodes'] - self.C = config['init_maximum_connections'] - self.S = config['init_maximum_species'] - - self.generation = 0 self.best_genome = None - self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config) - self.species_info = np.full((self.S, 4), np.nan) - self.species_info[0, :] = 0, -np.inf, 0, self.P - self.idx2species = np.zeros(self.P, dtype=np.float32) - self.center_nodes = np.full((self.S, self.N, 5), np.nan) - self.center_cons = np.full((self.S, self.C, 4), np.nan) - self.center_nodes[0, :, :] = self.pop_nodes[0, :, :] - self.center_cons[0, :, :] = self.pop_cons[0, :, :] + self.neat_states = neat.initialize(config) self.best_fitness = float('-inf') self.generation_timestamp = time.time() self.evaluate_time = 0 + + self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \ + self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config) + + + self.forward = neat.create_forward_function(config) self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections)) self.pop_topological_sort = jit(vmap(neat.topological_sort)) - self.forward = neat.create_forward_function(config) - # fitness_lower = np.zeros(self.P, dtype=np.float32) - # randkey_lower = np.zeros(2, dtype=np.uint32) - # pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32) - # pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32) - # species_info_lower = np.zeros((self.S, 4), dtype=np.float32) - # idx2species_lower = np.zeros(self.P, dtype=np.float32) - # center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32) - # center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32) - # - # self.tell_func = jit(neat.tell).lower(fitness_lower, - # randkey_lower, - # pop_nodes_lower, - # pop_cons_lower, - # species_info_lower, - # idx2species_lower, - # center_nodes_lower, - # center_cons_lower, - # 0, - # self.jit_config).compile() + # self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32), + # self.randkey, + # self.pop_nodes, + # self.pop_cons, + # self.species_info, + # self.idx2species, + # self.center_nodes, + # self.center_cons, + # self.generation, + # self.next_node_key, + # self.next_species_key, + # self.jit_config).compile() def ask(self): """ @@ -97,9 +79,19 @@ class Pipeline: def tell(self, fitness): self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \ - self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons, - self.species_info, self.idx2species, self.center_nodes, - self.center_cons, self.generation, self.jit_config) + self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness, + self.randkey, + self.pop_nodes, + self.pop_cons, + self.species_info, + self.idx2species, + self.center_nodes, + self.center_cons, + self.generation, + self.next_node_key, + self.next_species_key, + self.jit_config) + def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config['generation_limit']): @@ -109,7 +101,7 @@ class Pipeline: fitnesses = fitness_func(forward_func) self.evaluate_time += time.time() - tic - assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" + # assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" if analysis is not None: if analysis == "default": @@ -138,7 +130,8 @@ class Pipeline: self.best_fitness = fitnesses[max_idx] self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) - species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0] + member_count = jax.device_get(self.species_info[:, 3]) + species_sizes = [int(i) for i in member_count if i > 0] print(f"Generation: {self.generation}", f"species: {len(species_sizes)}, {species_sizes}",