From 72c9d4167a256795b3a1e7d5a497c834d458fff3 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 13 May 2023 20:58:03 +0800 Subject: [PATCH] FAST! --- algorithms/neat/function_factory.py | 249 ++++++++++------------------ algorithms/neat/genome/mutate.py | 66 ++------ algorithms/neat/jitable_speciate.py | 109 ------------ algorithms/neat/pipeline.py | 98 +++-------- algorithms/neat/population.py | 168 +++++++++++++++++++ algorithms/neat/species.py | 139 ++++------------ examples/jax_playground.py | 21 ++- examples/jitable_speciate_t.py | 36 ++-- examples/xor.py | 2 +- utils/default_config.json | 13 +- 10 files changed, 372 insertions(+), 529 deletions(-) delete mode 100644 algorithms/neat/jitable_speciate.py create mode 100644 algorithms/neat/population.py diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index f68194d..5a8c257 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -2,30 +2,34 @@ Lowers, compiles, and creates functions used in the NEAT pipeline. """ from functools import partial +import time -import jax.random import numpy as np from jax import jit, vmap -from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover +from .genome import act_name2key, agg_name2key, initialize_genomes from .genome import topological_sort, forward_single, unflatten_connections +from .population import create_next_generation_then_speciate class FunctionFactory: - def __init__(self, config, debug=False): + def __init__(self, config): self.config = config - self.debug = debug - self.init_N = config.basic.init_maximum_nodes - self.init_C = config.basic.init_maximum_connections self.expand_coe = config.basic.expands_coe self.precompile_times = config.basic.pre_compile_times self.compiled_function = {} + self.time_cost = {} self.load_config_vals(config) - self.precompile() + + self.create_topological_sort_with_args() + self.create_single_forward_with_args() + self.create_update_speciate_with_args() def load_config_vals(self, config): + self.compatibility_threshold = self.config.neat.species.compatibility_threshold + self.problem_batch = config.basic.problem_batch self.pop_size = config.neat.population.pop_size @@ -79,12 +83,12 @@ class FunctionFactory: self.delete_connection_rate = genome.conn_delete_prob self.single_structure_mutate = genome.single_structural_mutation - def create_initialize(self): + def create_initialize(self, N, C): func = partial( initialize_genomes, pop_size=self.pop_size, - N=self.init_N, - C=self.init_C, + N=N, + C=C, num_inputs=self.num_inputs, num_outputs=self.num_outputs, default_bias=self.bias_mean, @@ -93,166 +97,85 @@ class FunctionFactory: default_agg=self.agg_default, default_weight=self.weight_mean ) - if self.debug: - def debug_initialize(*args): - return func(*args) + return func - return debug_initialize - else: - return func + def create_update_speciate_with_args(self): + species_kwargs = { + "disjoint_coe": self.disjoint_coe, + "compatibility_coe": self.compatibility_coe, + "compatibility_threshold": self.compatibility_threshold + } - def precompile(self): - self.create_mutate_with_args() - self.create_distance_with_args() - self.create_crossover_with_args() - self.create_topological_sort_with_args() - self.create_single_forward_with_args() - # - # n, c = self.init_N, self.init_C - # print("start precompile") - # for _ in range(self.precompile_times): - # self.compile_mutate(n) - # self.compile_distance(n) - # self.compile_crossover(n) - # self.compile_topological_sort_batch(n) - # self.compile_pop_batch_forward(n) - # n = int(self.expand_coe * n) - # - # # precompile other functions used in jax - # key = jax.random.PRNGKey(0) - # _ = jax.random.split(key, 3) - # _ = jax.random.split(key, self.pop_size * 2) - # _ = jax.random.split(key, self.pop_size) - # - # print("end precompile") + mutate_kwargs = { + "input_idx": self.input_idx, + "output_idx": self.output_idx, + "bias_mean": self.bias_mean, + "bias_std": self.bias_std, + "bias_mutate_strength": self.bias_mutate_strength, + "bias_mutate_rate": self.bias_mutate_rate, + "bias_replace_rate": self.bias_replace_rate, + "response_mean": self.response_mean, + "response_std": self.response_std, + "response_mutate_strength": self.response_mutate_strength, + "response_mutate_rate": self.response_mutate_rate, + "response_replace_rate": self.response_replace_rate, + "weight_mean": self.weight_mean, + "weight_std": self.weight_std, + "weight_mutate_strength": self.weight_mutate_strength, + "weight_mutate_rate": self.weight_mutate_rate, + "weight_replace_rate": self.weight_replace_rate, + "act_default": self.act_default, + "act_list": self.act_list, + "act_replace_rate": self.act_replace_rate, + "agg_default": self.agg_default, + "agg_list": self.agg_list, + "agg_replace_rate": self.agg_replace_rate, + "enabled_reverse_rate": self.enabled_reverse_rate, + "add_node_rate": self.add_node_rate, + "delete_node_rate": self.delete_node_rate, + "add_connection_rate": self.add_connection_rate, + "delete_connection_rate": self.delete_connection_rate, + } - def create_mutate_with_args(self): - func = partial( - mutate, - input_idx=self.input_idx, - output_idx=self.output_idx, - bias_mean=self.bias_mean, - bias_std=self.bias_std, - bias_mutate_strength=self.bias_mutate_strength, - bias_mutate_rate=self.bias_mutate_rate, - bias_replace_rate=self.bias_replace_rate, - response_mean=self.response_mean, - response_std=self.response_std, - response_mutate_strength=self.response_mutate_strength, - response_mutate_rate=self.response_mutate_rate, - response_replace_rate=self.response_replace_rate, - weight_mean=self.weight_mean, - weight_std=self.weight_std, - weight_mutate_strength=self.weight_mutate_strength, - weight_mutate_rate=self.weight_mutate_rate, - weight_replace_rate=self.weight_replace_rate, - act_default=self.act_default, - act_list=self.act_list, - act_replace_rate=self.act_replace_rate, - agg_default=self.agg_default, - agg_list=self.agg_list, - agg_replace_rate=self.agg_replace_rate, - enabled_reverse_rate=self.enabled_reverse_rate, - add_node_rate=self.add_node_rate, - delete_node_rate=self.delete_node_rate, - add_connection_rate=self.add_connection_rate, - delete_connection_rate=self.delete_connection_rate, - single_structure_mutate=self.single_structure_mutate + self.update_speciate_with_args = partial( + create_next_generation_then_speciate, + species_kwargs=species_kwargs, + mutate_kwargs=mutate_kwargs ) - self.mutate_with_args = func - def compile_mutate(self, n, c): - func = self.mutate_with_args - rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) - nodes_lower = np.zeros((self.pop_size, n, 5)) - connections_lower = np.zeros((self.pop_size, c, 4)) - new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32) - batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower, - connections_lower, new_node_key_lower).compile() - self.compiled_function[('mutate', n, c)] = batched_mutate_func - - def create_mutate(self, n, c): - key = ('mutate', n, c) + def create_update_speciate(self, N, C, S): + key = ("update_speciate", N, C, S) if key not in self.compiled_function: - self.compile_mutate(n, c) - if self.debug: - def debug_mutate(*args): - res_nodes, res_connections = self.compiled_function[key](*args) - return res_nodes.block_until_ready(), res_connections.block_until_ready() + self.compile_update_speciate(N, C, S) + return self.compiled_function[key] - return debug_mutate - else: - return self.compiled_function[key] - - def create_distance_with_args(self): - func = partial( - distance, - disjoint_coe=self.disjoint_coe, - compatibility_coe=self.compatibility_coe - ) - self.distance_with_args = func - - def compile_distance(self, n, c): - func = self.distance_with_args - o2o_nodes1_lower = np.zeros((n, 5)) - o2o_connections1_lower = np.zeros((c, 4)) - o2o_nodes2_lower = np.zeros((n, 5)) - o2o_connections2_lower = np.zeros((c, 4)) - o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower, - o2o_nodes2_lower, o2o_connections2_lower).compile() - - o2m_nodes2_lower = np.zeros((self.pop_size, n, 5)) - o2m_connections2_lower = np.zeros((self.pop_size, c, 4)) - o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower, - o2m_nodes2_lower, - o2m_connections2_lower).compile() - - self.compiled_function[('o2o_distance', n, c)] = o2o_distance - self.compiled_function[('o2m_distance', n, c)] = o2m_distance - - def create_distance(self, n, c): - key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c) - if key1 not in self.compiled_function: - self.compile_distance(n, c) - if self.debug: - - def debug_o2o_distance(*args): - return self.compiled_function[key1](*args).block_until_ready() - - def debug_o2m_distance(*args): - return self.compiled_function[key2](*args).block_until_ready() - - return debug_o2o_distance, debug_o2m_distance - else: - return self.compiled_function[key1], self.compiled_function[key2] - - def create_crossover_with_args(self): - self.crossover_with_args = crossover - - def compile_crossover(self, n, c): - func = self.crossover_with_args - randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) - nodes1_lower = np.zeros((self.pop_size, n, 5)) - connections1_lower = np.zeros((self.pop_size, c, 4)) - nodes2_lower = np.zeros((self.pop_size, n, 5)) - connections2_lower = np.zeros((self.pop_size, c, 4)) - func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower, - nodes2_lower, connections2_lower).compile() - self.compiled_function[('crossover', n, c)] = func - - def create_crossover(self, n, c): - key = ('crossover', n, c) - if key not in self.compiled_function: - self.compile_crossover(n, c) - if self.debug: - - def debug_crossover(*args): - res_nodes, res_connections = self.compiled_function[key](*args) - return res_nodes.block_until_ready(), res_connections.block_until_ready() - - return debug_crossover - else: - return self.compiled_function[key] + def compile_update_speciate(self, N, C, S): + func = self.update_speciate_with_args + randkey_lower = np.zeros((2,), dtype=np.uint32) + pop_nodes_lower = np.zeros((self.pop_size, N, 5)) + pop_cons_lower = np.zeros((self.pop_size, C, 4)) + winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32) + loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32) + elite_mask_lower = np.zeros((self.pop_size,), dtype=bool) + new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32) + pre_spe_center_nodes_lower = np.zeros((S, N, 5)) + pre_spe_center_cons_lower = np.zeros((S, C, 4)) + species_keys = np.zeros((S,), dtype=np.int32) + new_species_keys_lower = 0 + compiled_func = jit(func).lower( + randkey_lower, + pop_nodes_lower, + pop_cons_lower, + winner_part_lower, + loser_part_lower, + elite_mask_lower, + new_node_keys_start_lower, + pre_spe_center_nodes_lower, + pre_spe_center_cons_lower, + species_keys, + new_species_keys_lower, + ).compile() + self.compiled_function[("update_speciate", N, C, S)] = compiled_func def create_topological_sort_with_args(self): self.topological_sort_with_args = topological_sort diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index dd155be..88c56ce 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -11,7 +11,8 @@ from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_ from .graph import check_cycles -@partial(jit, static_argnames=('single_structure_mutate',)) +# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible. +@jit def mutate(rand_key: Array, nodes: Array, connections: Array, @@ -44,7 +45,7 @@ def mutate(rand_key: Array, delete_node_rate: float = 0.2, add_connection_rate: float = 0.4, delete_connection_rate: float = 0.4, - single_structure_mutate: bool = True): + ): """ :param output_idx: :param input_idx: @@ -78,65 +79,26 @@ def mutate(rand_key: Array, :param delete_node_rate: :param add_connection_rate: :param delete_connection_rate: - :param single_structure_mutate: a genome is structurally mutate at most once :return: """ - # mutate_structure - def nothing(rk, n, c): - return n, c - def m_add_node(rk, n, c): return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default) - def m_delete_node(rk, n, c): - return mutate_delete_node(rk, n, c, input_idx, output_idx) - def m_add_connection(rk, n, c): return mutate_add_connection(rk, n, c, input_idx, output_idx) - def m_delete_connection(rk, n, c): - return mutate_delete_connection(rk, n, c) + r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) - mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection] + # mutate add node + aux_nodes, aux_connections = m_add_node(r1, nodes, connections) + nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes) + connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections) - if single_structure_mutate: - r1, r2, rand_key = jax.random.split(rand_key, 3) - d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate) - - # shorten variable names for beauty - anr, dnr = add_node_rate / d, delete_node_rate / d - acr, dcr = add_connection_rate / d, delete_connection_rate / d - - r = rand(r1) - branch = 0 - branch = jnp.where(r <= anr, 1, branch) - branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch) - branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch) - branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch) - nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections)) - else: - r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) - - # mutate add node - aux_nodes, aux_connections = m_add_node(r1, nodes, connections) - nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes) - connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections) - - # mutate delete node - aux_nodes, aux_connections = m_delete_node(r2, nodes, connections) - nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes) - connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections) - - # mutate add connection - aux_nodes, aux_connections = m_add_connection(r3, nodes, connections) - nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes) - connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections) - - # mutate delete connection - aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections) - nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes) - connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections) + # mutate add connection + aux_nodes, aux_connections = m_add_connection(r3, nodes, connections) + nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes) + connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections) nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, bias_replace_rate, response_mean, response_std, @@ -379,9 +341,9 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, # randomly choose two nodes k1, k2 = jax.random.split(rand_key, num=2) i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys, - allow_input_keys=True, allow_output_keys=True) + allow_input_keys=True, allow_output_keys=True) o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys, - allow_input_keys=False, allow_output_keys=True) + allow_input_keys=False, allow_output_keys=True) con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) diff --git a/algorithms/neat/jitable_speciate.py b/algorithms/neat/jitable_speciate.py deleted file mode 100644 index 8f2192e..0000000 --- a/algorithms/neat/jitable_speciate.py +++ /dev/null @@ -1,109 +0,0 @@ -from functools import partial - -import jax -import jax.numpy as jnp -from jax import jit, vmap - -from jax import Array - -from .genome import distance -from .genome.utils import I_INT, fetch_first, argmin_with_mask - - -@jit -def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array, - disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0 - ): - """ - args: - pop_nodes: (pop_size, N, 5) - pop_cons: (pop_size, C, 4) - spe_center_nodes: (species_size, N, 5) - spe_center_cons: (species_size, C, 4) - """ - pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0] - - # prepare distance functions - distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe) - o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) - s2p_distance_func = vmap( - o2p_distance_func, in_axes=(0, 0, None, None) - ) - - idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species - - # the distance between each species' center and each genome in population - s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons) - - def continue_execute_while(carry): - i, i2s, scn, scc = carry - not_all_assigned = ~jnp.all(i2s != I_INT) - not_reach_species_upper_bounds = i < species_size - return not_all_assigned & not_reach_species_upper_bounds - - def deal_with_each_center_genome(carry): - i, i2s, scn, scc = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons - center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i] - - i2s, scn, scc = jax.lax.cond( - jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid - create_new_specie, # if not valid, create a new specie - update_exist_specie, # if valid, update the specie - (i, i2s, scn, scc) - ) - - return i + 1, i2s, scn, scc - - def create_new_specie(carry): - i, i2s, scn, scc = carry - # pick the first one who has not been assigned to any species - idx = fetch_first(i2s == I_INT) - - # assign it to new specie - i2s = i2s.at[idx].set(i) - - # update center genomes - scn = scn.at[i].set(pop_nodes[idx]) - scc = scc.at[i].set(pop_cons[idx]) - - i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc)) - return i2s, scn, scc - - def update_exist_specie(carry): - i, i2s, scn, scc = carry - - # find new center - idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT) - - # update new center - i2s = i2s.at[idx].set(i) - - # update center genomes - scn = scn.at[i].set(pop_nodes[idx]) - scc = scc.at[i].set(pop_cons[idx]) - - i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc)) - return i2s, scn, scc - - def speciate_by_threshold(carry): - i, i2s, scn, scc = carry - # distance between such center genome and ppo genomes - o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons) - close_enough_mask = o2p_distance < compatibility_threshold - - # when it is close enough, assign it to the species, remember not to update genome has already been assigned - i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s) - return i2s, scn, scc - - # update idx2specie - _, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop( - continue_execute_while, - deal_with_each_center_genome, - (0, idx2specie, spe_center_nodes, spe_center_cons) - ) - - # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition seems to be only happened when the number of species is reached species upper bounds - idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie) - - return idx2specie, spe_center_nodes, spe_center_cons diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index f199cb7..27c0083 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -7,8 +7,9 @@ import numpy as np from .species import SpeciesController from .genome import expand, expand_single from .function_factory import FunctionFactory -from .genome.genome import count -from .genome.debug.tools import check_array_valid + +from .population import * + class Pipeline: """ @@ -17,7 +18,7 @@ class Pipeline: def __init__(self, config, seed=42): self.time_dict = {} - self.function_factory = FunctionFactory(config, debug=True) + self.function_factory = FunctionFactory(config) self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) @@ -25,17 +26,18 @@ class Pipeline: self.config = config self.N = config.basic.init_maximum_nodes self.C = config.basic.init_maximum_connections + self.S = config.basic.init_maximum_species self.expand_coe = config.basic.expands_coe self.pop_size = config.neat.population.pop_size self.species_controller = SpeciesController(config) self.initialize_func = self.function_factory.create_initialize() - self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() + self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func() - self.compile_functions(debug=True) + self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S) self.generation = 0 - self.species_controller.init_speciate(self.pop_nodes, self.pop_connections) + self.species_controller.init_speciate(self.pop_nodes, self.pop_cons) self.best_fitness = float('-inf') self.best_genome = None @@ -47,22 +49,26 @@ class Pipeline: :return: Algorithm gives the population a forward function, then environment gives back the fitnesses. """ - return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections) + return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_cons) def tell(self, fitnesses): self.generation += 1 - self.species_controller.update_species_fitnesses(fitnesses) + winner_part, loser_part, elite_mask, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start = self.species_controller.ask( + fitnesses, + self.generation, + self.S, self.N, self.C) - winner_part, loser_part, elite_mask = self.species_controller.reproduce(fitnesses, self.generation) + new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) + self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = self.create_and_speciate( + self.randkey, self.pop_nodes, self.pop_cons, winner_part, loser_part, elite_mask, + new_node_keys, + pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start) - self.update_next_generation(winner_part, loser_part, elite_mask) + idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys]) - # pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx) - - self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation, - self.o2o_distance, self.o2m_distance) + self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation) self.expand() @@ -86,49 +92,6 @@ class Pipeline: print("Generation limit reached!") return self.best_genome - def update_next_generation(self, winner_part, loser_part, elite_mask) -> None: - """ - create next generation - :param winner_part: - :param loser_part: - :param elite_mask: - :return: - """ - - assert self.pop_nodes.shape[0] == self.pop_size - k1, k2, self.randkey = jax.random.split(self.randkey, 3) - - crossover_rand_keys = jax.random.split(k1, self.pop_size) - mutate_rand_keys = jax.random.split(k2, self.pop_size) - - # batch crossover - wpn = self.pop_nodes[winner_part] # winner pop nodes - wpc = self.pop_connections[winner_part] # winner pop connections - lpn = self.pop_nodes[loser_part] # loser pop nodes - lpc = self.pop_connections[loser_part] # loser pop connections - - npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, - lpc) # new pop nodes, new pop connections - - # for i in range(self.pop_size): - # n, c = np.array(npn[i]), np.array(npc[i]) - # check_array_valid(n, c, self.input_idx, self.output_idx) - - # mutate - new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) - - m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes - - # for i in range(self.pop_size): - # n, c = np.array(m_npn[i]), np.array(m_npc[i]) - # check_array_valid(n, c, self.input_idx, self.output_idx) - - # elitism don't mutate - npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) - - self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn) - self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc) - def expand(self): """ Expand the population if needed. @@ -142,37 +105,28 @@ class Pipeline: if max_node_size >= self.N: self.N = int(self.N * self.expand_coe) print(f"node expand to {self.N}!") - self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C) + self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C) # don't forget to expand representation genome in species for s in self.species_controller.species.values(): s.representative = expand_single(*s.representative, self.N, self.C) - # update functions - self.compile_functions(debug=True) - - pop_con_keys = self.pop_connections[:, :, 0] + pop_con_keys = self.pop_cons[:, :, 0] pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) max_con_size = np.max(pop_node_sizes) if max_con_size >= self.C: self.C = int(self.C * self.expand_coe) print(f"connections expand to {self.C}!") - self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C) + self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C) # don't forget to expand representation genome in species for s in self.species_controller.species.values(): s.representative = expand_single(*s.representative, self.N, self.C) - # update functions - self.compile_functions(debug=True) + self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S) - - - def compile_functions(self, debug=False): - self.mutate_func = self.function_factory.create_mutate(self.N, self.C) - self.crossover_func = self.function_factory.create_crossover(self.N, self.C) - self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C) + def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) @@ -185,7 +139,7 @@ class Pipeline: max_idx = np.argmax(fitnesses) if fitnesses[max_idx] > self.best_fitness: self.best_fitness = fitnesses[max_idx] - self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx]) + self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) print(f"Generation: {self.generation}", f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py new file mode 100644 index 0000000..9ad0bb6 --- /dev/null +++ b/algorithms/neat/population.py @@ -0,0 +1,168 @@ +from functools import partial + +import jax +import jax.numpy as jnp +from jax import jit, vmap + +from jax import Array + +from .genome import distance, mutate, crossover +from .genome.utils import I_INT, fetch_first, argmin_with_mask + + +@jit +def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, + new_node_keys, + pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start, + species_kwargs, mutate_kwargs): + # create next generation + pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, + new_node_keys, **mutate_kwargs) + + # speciate + idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes, + pre_spe_center_cons, species_keys, + new_species_key_start, **species_kwargs) + + return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys + + +@jit +def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array, + species_keys, new_species_key_start, + disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0 + ): + """ + args: + pop_nodes: (pop_size, N, 5) + pop_cons: (pop_size, C, 4) + spe_center_nodes: (species_size, N, 5) + spe_center_cons: (species_size, C, 4) + """ + pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0] + + # prepare distance functions + distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe) + o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) + s2p_distance_func = vmap( + o2p_distance_func, in_axes=(0, 0, None, None) + ) + + # idx to specie key + idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species + + # part 1: find new centers + # the distance between each species' center and each genome in population + s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons) + + def find_new_centers(i, carry): + i2s, scn, scc = carry + # find new center + idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT) + + # check species[i] exist or not + # if not exist, set idx and i to I_INT, jax will not do array value assignment + idx = jnp.where(species_keys[i] != I_INT, idx, I_INT) + i = jnp.where(species_keys[i] != I_INT, i, I_INT) + + i2s = i2s.at[idx].set(species_keys[i]) + scn = scn.at[i].set(pop_nodes[idx]) + scc = scc.at[i].set(pop_cons[idx]) + return i2s, scn, scc + + idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons)) + + def continue_execute_while(carry): + i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key + not_all_assigned = ~jnp.all(i2s != I_INT) + not_reach_species_upper_bounds = i < species_size + return not_all_assigned & not_reach_species_upper_bounds + + def deal_with_each_center_genome(carry): + i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons + center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i] + + i2s, scn, scc, sk, ck = jax.lax.cond( + jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid + create_new_specie, # if not valid, create a new specie + update_exist_specie, # if valid, update the specie + (i, i2s, scn, scc, sk, ck) + ) + + return i + 1, i2s, scn, scc, sk, ck + + def create_new_specie(carry): + i, i2s, scn, scc, sk, ck = carry + # pick the first one who has not been assigned to any species + idx = fetch_first(i2s == I_INT) + + # assign it to new specie + sk = sk.at[i].set(ck) + i2s = i2s.at[idx].set(ck) + + # update center genomes + scn = scn.at[i].set(pop_nodes[idx]) + scc = scc.at[i].set(pop_cons[idx]) + + i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk)) + return i2s, scn, scc, sk, ck + 1 # change to next new speciate key + + def update_exist_specie(carry): + i, i2s, scn, scc, sk, ck = carry + + i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk)) + return i2s, scn, scc, sk, ck + + def speciate_by_threshold(carry): + i, i2s, scn, scc, sk = carry + # distance between such center genome and ppo genomes + o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons) + close_enough_mask = o2p_distance < compatibility_threshold + + # when it is close enough, assign it to the species, remember not to update genome has already been assigned + i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s) + return i2s, scn, scc + + current_new_key = new_species_key_start + + # update idx2specie + _, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop( + continue_execute_while, + deal_with_each_center_genome, + (0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key) + ) + + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition seems to be only happened when the number of species is reached species upper bounds + idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie) + + return idx2specie, spe_center_nodes, spe_center_cons, species_keys + + +@jit +def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys, + **mutate_kwargs): + # prepare functions + batch_crossover = vmap(crossover) + mutate_with_args = vmap(partial(mutate, **mutate_kwargs)) + + pop_size = pop_nodes.shape[0] + k1, k2 = jax.random.split(rand_key, 2) + crossover_rand_keys = jax.random.split(k1, pop_size) + mutate_rand_keys = jax.random.split(k2, pop_size) + + # batch crossover + wpn = pop_nodes[winner_part] # winner pop nodes + wpc = pop_cons[winner_part] # winner pop connections + lpn = pop_nodes[loser_part] # loser pop nodes + lpc = pop_cons[loser_part] # loser pop connections + + npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections + + m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) + pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) + + return pop_nodes, pop_cons diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index d9f2079..ff3e015 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -5,6 +5,8 @@ import jax import numpy as np from numpy.typing import NDArray +from .genome.utils import I_INT + class Species(object): @@ -12,7 +14,7 @@ class Species(object): self.key = key self.created = generation self.last_improved = generation - self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections) + self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections) self.members: NDArray = None # idx in pop_nodes, pop_connections, self.fitness = None self.member_fitnesses = None @@ -34,7 +36,7 @@ class SpeciesController: def __init__(self, config): self.config = config - self.compatibility_threshold = self.config.neat.species.compatibility_threshold + self.species_elitism = self.config.neat.species.species_elitism self.pop_size = self.config.neat.population.pop_size self.max_stagnation = self.config.neat.species.max_stagnation @@ -59,97 +61,7 @@ class SpeciesController: s.update((pop_nodes[0], pop_connections[0]), members) self.species[species_id] = s - def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int, - o2o_distance: Callable, o2m_distance: Callable) -> None: - """ - :param pop_nodes: - :param pop_connections: - :param generation: use to flag the created time of new species - :param o2o_distance: distance function for one-to-one comparison - :param o2m_distance: distance function for one-to-many comparison - :return: - """ - unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool) - previous_species_list = list(self.species.keys()) - - # Find the best representatives for each existing species. - new_representatives = {} - new_members = {} - - total_distances = jax.device_get([ - o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections) - for sid in previous_species_list - ]) - - # TODO: Use jit to wrapper function find_min_with_mask to accelerate this part - for i, sid in enumerate(previous_species_list): - distances = total_distances[i] - min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance - - new_representatives[sid] = min_idx - new_members[sid] = [min_idx] - unspeciated[min_idx] = False - - # Partition population into species based on genetic similarity. - - # First, fast match the population to previous species - if previous_species_list: # exist previous species - rid_list = [new_representatives[sid] for sid in previous_species_list] - res_pop_distance = jax.device_get([ - o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) - for rid in rid_list - ]) - - pop_res_distance = np.stack(res_pop_distance, axis=0).T - for i in range(pop_res_distance.shape[0]): - if not unspeciated[i]: - continue - min_idx = np.argmin(pop_res_distance[i]) - min_val = pop_res_distance[i, min_idx] - if min_val <= self.compatibility_threshold: - species_id = previous_species_list[min_idx] - new_members[species_id].append(i) - unspeciated[i] = False - - # Second, slowly match the lonely population to new-created species.s - # lonely genome is proved to be not compatible with any previous species, so they only need to be compared with - # the new representatives. - for i in range(pop_nodes.shape[0]): - if not unspeciated[i]: - continue - unspeciated[i] = False - if len(new_representatives) != 0: - # the representatives of new species - sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) - distances = jax.device_get([ - o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) - for r in rid - ]) - distances = np.array(distances) - min_idx = np.argmin(distances) - min_val = distances[min_idx] - if min_val <= self.compatibility_threshold: - species_id = sid[min_idx] - new_members[species_id].append(i) - continue - # create a new species - species_id = next(self.species_idxer) - new_representatives[species_id] = i - new_members[species_id] = [i] - - assert np.all(~unspeciated) - - # Update species collection based on new speciation. - for sid, rid in new_representatives.items(): - s = self.species.get(sid) - if s is None: - s = Species(sid, generation) - self.species[sid] = s - - members = np.array(new_members[sid]) - s.update((pop_nodes[rid], pop_connections[rid]), members) - - def update_species_fitnesses(self, fitnesses): + def __update_species_fitnesses(self, fitnesses): """ update the fitness of each species :param fitnesses: @@ -163,7 +75,7 @@ class SpeciesController: s.fitness_history.append(s.fitness) s.adjusted_fitness = None - def stagnation(self, generation): + def __stagnation(self, generation): """ code modified from neat-python! :param generation: @@ -196,7 +108,7 @@ class SpeciesController: result.append((sid, s, is_stagnant)) return result - def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]: + def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]: """ code modified from neat-python! :param fitnesses: @@ -215,7 +127,7 @@ class SpeciesController: max_fitness = -np.inf remaining_species = [] - for stag_sid, stag_s, stagnant in self.stagnation(generation): + for stag_sid, stag_s, stagnant in self.__stagnation(generation): if not stagnant: min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses)) max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses)) @@ -285,6 +197,33 @@ class SpeciesController: return winner_part, loser_part, np.array(elite_mask) + def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation): + for idx, key in enumerate(species_keys): + if key == I_INT: + continue + + members = np.where(idx2specie == key)[0] + assert len(members) > 0 + if key not in self.species: + s = Species(key, generation) + self.species[key] = s + + self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members) + + def ask(self, fitnesses, generation, S, N, C): + self.__update_species_fitnesses(fitnesses) + winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation) + pre_spe_center_nodes = np.full((S, N, 5), np.nan) + pre_spe_center_cons = np.full((S, C, 4), np.nan) + species_keys = np.full((S,), I_INT) + for idx, (key, specie) in enumerate(self.species.items()): + pre_spe_center_nodes[idx] = specie.representative[0] + pre_spe_center_cons[idx] = specie.representative[1] + species_keys[idx] = key + next_new_specie_key = max(self.species.keys()) + 1 + return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \ + pre_spe_center_cons, species_keys, next_new_specie_key + def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): """ @@ -326,13 +265,7 @@ def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): return spawn_amounts -def find_min_with_mask(arr: NDArray, mask: NDArray) -> int: - masked_arr = np.where(mask, arr, np.inf) - min_idx = np.argmin(masked_arr) - return min_idx - - def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \ -> Tuple[NDArray, NDArray]: sorted_idx = np.argsort(fitnesses)[::-1] - return members[sorted_idx], fitnesses[sorted_idx] \ No newline at end of file + return members[sorted_idx], fitnesses[sorted_idx] diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 63df18f..e3e17b9 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -4,9 +4,10 @@ from jax import jit, vmap from time_utils import using_cprofile from time import time # +import numpy as np @jit -def fx(x, y): - return x + y +def fx(x): + return jnp.arange(x, x + 10) # # # # @jit @@ -33,13 +34,15 @@ def fx(x, y): # @using_cprofile def main(): - vmap_f = vmap(fx, in_axes=(None, 0)) - vmap_vmap_f = vmap(vmap_f, in_axes=(0, None)) - a = jnp.array([20,10,30]) - b = jnp.array([6, 5, 4]) - res = vmap_vmap_f(a, b) - print(res) - print(jnp.argmin(res, axis=1)) + print(fx(1)) + + # vmap_f = vmap(fx, in_axes=(None, 0)) + # vmap_vmap_f = vmap(vmap_f, in_axes=(0, None)) + # a = jnp.array([20,10,30]) + # b = jnp.array([6, 5, 4]) + # res = vmap_vmap_f(a, b) + # print(res) + # print(jnp.argmin(res, axis=1)) diff --git a/examples/jitable_speciate_t.py b/examples/jitable_speciate_t.py index 8629665..9eae841 100644 --- a/examples/jitable_speciate_t.py +++ b/examples/jitable_speciate_t.py @@ -4,7 +4,7 @@ import numpy as np from algorithms.neat.function_factory import FunctionFactory from algorithms.neat.genome.debug.tools import check_array_valid from utils import Configer -from algorithms.neat.jitable_speciate import jitable_speciate +from algorithms.neat.population import speciate from algorithms.neat.genome.crossover import crossover from algorithms.neat.genome.utils import I_INT from time import time @@ -23,7 +23,9 @@ if __name__ == '__main__': spe_center_connections = np.full((species_size, C, 4), np.nan) spe_center_nodes[0] = pop_nodes[0] spe_center_connections[0] = pop_connections[0] - + spe_keys = np.full((species_size,), I_INT) + spe_keys[0] = 0 + new_spe_key = 1 key = jax.random.PRNGKey(0) new_node_idx = 100 @@ -43,25 +45,31 @@ if __name__ == '__main__': n1, c1 = pop_nodes[idx1], pop_connections[idx1] n2, c2 = pop_nodes[idx2], pop_connections[idx2] crossover_keys = jax.random.split(subkey, len(pop_nodes)) + pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2) + # for i in range(len(pop_nodes)): # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) #speciate next generation - idx2specie, spe_center_nodes, spe_center_cons = jitable_speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections, - compatibility_threshold=2.5) + idx2specie, spe_center_nodes, spe_center_cons, spe_keys, new_spe_key = speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections, + spe_keys, new_spe_key, + compatibility_threshold=3) - idx2specie = np.array(idx2specie) - spe_dict = {} - for i in range(len(idx2specie)): - spe_idx = idx2specie[i] - if spe_idx not in spe_dict: - spe_dict[spe_idx] = 1 - else: - spe_dict[spe_idx] += 1 + print(spe_keys, new_spe_key) - print(spe_dict) - assert np.all(idx2specie != I_INT) + # + # idx2specie = np.array(idx2specie) + # spe_dict = {} + # for i in range(len(idx2specie)): + # spe_idx = idx2specie[i] + # if spe_idx not in spe_dict: + # spe_dict[spe_idx] = 1 + # else: + # spe_dict[spe_idx] += 1 + # + # print(spe_dict) + # assert np.all(idx2specie != I_INT) print(time() - start_time) # print(idx2specie) diff --git a/examples/xor.py b/examples/xor.py index 2b65e8f..61d7398 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -12,7 +12,7 @@ def main(): config = Configer.load_config() problem = Xor() problem.refactor_config(config) - pipeline = Pipeline(config, seed=0) + pipeline = Pipeline(config, seed=1) pipeline.auto_run(problem.evaluate) diff --git a/utils/default_config.json b/utils/default_config.json index 4ba0142..27aaa70 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -3,8 +3,9 @@ "num_inputs": 2, "num_outputs": 1, "problem_batch": 4, - "init_maximum_nodes": 50, - "init_maximum_connections": 50, + "init_maximum_nodes": 20, + "init_maximum_connections": 20, + "init_maximum_species": 10, "expands_coe": 2, "pre_compile_times": 3, "forward_way": "pop_batch" @@ -14,7 +15,7 @@ "fitness_criterion": "max", "fitness_threshold": -0.001, "generation_limit": 1000, - "pop_size": 10000, + "pop_size": 1000, "reset_on_extinction": "False" }, "gene": { @@ -58,12 +59,12 @@ "compatibility_weight_coefficient": 0.5, "single_structural_mutation": "False", "conn_add_prob": 0.5, - "conn_delete_prob": 0.5, + "conn_delete_prob": 0, "node_add_prob": 0.2, - "node_delete_prob": 0.2 + "node_delete_prob": 0 }, "species": { - "compatibility_threshold": 2.5, + "compatibility_threshold": 3, "species_fitness_func": "max", "max_stagnation": 20, "species_elitism": 2,