From ba369db0b2ebbd52b5b6d5555c789c93e3a5d21e Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 25 Jun 2023 02:57:45 +0800 Subject: [PATCH] Perfect! Next is to connect with Evox! --- configs/default_config.ini | 10 +-- examples/jax_playground.py | 8 +- examples/xor.ini | 2 +- examples/xor.py | 31 ++----- neat/__init__.py | 3 + neat/function_factory.py | 47 +++++++--- neat/genome/__init__.py | 7 ++ neat/genome/aggregations.py | 11 +-- neat/genome/mutate.py | 5 +- neat/genome/utils.py | 14 +-- neat/operations.py | 171 ++++++++++++++++++++++++++++++++++++ neat/pipeline.py | 118 ++++++++++++++++++++++++- neat/population.py | 168 ----------------------------------- neat/species.py | 65 ++++++++------ 14 files changed, 392 insertions(+), 268 deletions(-) create mode 100644 neat/operations.py delete mode 100644 neat/population.py diff --git a/configs/default_config.ini b/configs/default_config.ini index b45b49f..d4bbcda 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -1,10 +1,10 @@ [basic] num_inputs = 2 num_outputs = 1 -init_maximum_nodes = 20 -init_maximum_connections = 20 +init_maximum_nodes = 50 +init_maximum_connections = 50 init_maximum_species = 10 -expands_coe = 2.0 +expand_coe = 2.0 forward_way = "pop" batch_size = 4 @@ -12,7 +12,7 @@ batch_size = 4 fitness_threshold = 100000 generation_limit = 100 fitness_criterion = "max" -pop_size = 150 +pop_size = 15000 [genome] compatibility_disjoint = 1.0 @@ -26,7 +26,7 @@ node_delete_prob = 0 [species] compatibility_threshold = 3.0 species_elitism = 2 -species_max_stagnation = 15 +max_stagnation = 15 genome_elitism = 2 survival_threshold = 0.2 min_species_size = 1 diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 379c28a..0969908 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,20 +1,16 @@ -from functools import partial - import numpy as np -import jax from jax import jit from configs import Configer from neat.pipeline import Pipeline -from neat.function_factory import FunctionFactory xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) def main(): config = Configer.load_config("xor.ini") - function_factory = FunctionFactory(config) - pipeline = Pipeline(config, function_factory) + print(config) + pipeline = Pipeline(config) forward_func = pipeline.ask() # inputs = np.tile(xor_inputs, (150, 1, 1)) outputs = forward_func(xor_inputs) diff --git a/examples/xor.ini b/examples/xor.ini index 79a1110..233ace7 100644 --- a/examples/xor.ini +++ b/examples/xor.ini @@ -2,4 +2,4 @@ forward_way = "common" [population] -fitness_threshold = -1e-2 \ No newline at end of file +fitness_threshold = 3.9999 \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py index 9ff70ca..b509c76 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,45 +1,26 @@ -from typing import Callable, List -import time - import numpy as np from configs import Configer -from neat import Pipeline +from neat.pipeline import Pipeline -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) -xor_outputs = np.array([[0], [1], [1], [0]]) +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) -def evaluate(forward_func: Callable) -> List[float]: +def evaluate(forward_func): """ :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) :return: """ outs = forward_func(xor_inputs) fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - # print(fitnesses) - return fitnesses.tolist() # returns a list + return np.array(fitnesses) # returns a list -# @using_cprofile -# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/") def main(): - tic = time.time() config = Configer.load_config("xor.ini") - print(config) - function_factory = FunctionFactory(config) - pipeline = Pipeline(config, function_factory, seed=6) + pipeline = Pipeline(config, seed=6) nodes, cons = pipeline.auto_run(evaluate) - print(nodes, cons) - total_time = time.time() - tic - compile_time = pipeline.function_factory.compile_time - total_it = pipeline.generation - mean_time_per_it = (total_time - compile_time) / total_it - evaluate_time = pipeline.evaluate_time - print( - f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s") - print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s") - if __name__ == '__main__': main() diff --git a/neat/__init__.py b/neat/__init__.py index e69de29..8f8da22 100644 --- a/neat/__init__.py +++ b/neat/__init__.py @@ -0,0 +1,3 @@ +""" +contains operations on a single genome. e.g. forward, mutate, crossover, etc. +""" \ No newline at end of file diff --git a/neat/function_factory.py b/neat/function_factory.py index effed6a..953cd98 100644 --- a/neat/function_factory.py +++ b/neat/function_factory.py @@ -1,10 +1,8 @@ import numpy as np from jax import jit, vmap -from .genome.forward import create_forward -from .genome.utils import unflatten_connections -from .genome.graph import topological_sort - +from .genome import create_forward, topological_sort, unflatten_connections +from .operations import create_next_generation_then_speciate def hash_symbols(symbols): return symbols['P'], symbols['N'], symbols['C'], symbols['S'] @@ -15,8 +13,10 @@ class FunctionFactory: Creates and compiles functions used in the NEAT pipeline. """ - def __init__(self, config): + def __init__(self, config, jit_config): self.config = config + self.jit_config = jit_config + self.func_dict = {} self.function_info = {} @@ -78,6 +78,24 @@ class FunctionFactory: {'shape': ('P', 'N', 5), 'type': np.float32}, {'shape': ('P', 2, 'N', 'N'), 'type': np.float32} ] + }, + + 'create_next_generation_then_speciate': { + 'func': create_next_generation_then_speciate, + 'lowers': [ + {'shape': (2, ), 'type': np.uint32}, # rand_key + {'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes + {'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons + {'shape': ('P', ), 'type': np.int32}, # winner + {'shape': ('P', ), 'type': np.int32}, # loser + {'shape': ('P', ), 'type': bool}, # elite_mask + {'shape': ('P',), 'type': np.int32}, # new_node_keys + {'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes + {'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons + {'shape': ('S', ), 'type': np.int32}, # species_keys + {'shape': (), 'type': np.int32}, # new_species_key_start + "jit_config" + ] } } @@ -94,12 +112,19 @@ class FunctionFactory: # prepare lower operands lowers_operands = [] for lower in self.function_info[name]['lowers']: - shape = list(lower['shape']) - for i, s in enumerate(shape): - if s in symbols: - shape[i] = symbols[s] - assert isinstance(shape[i], int) - lowers_operands.append(np.zeros(shape, dtype=lower['type'])) + if isinstance(lower, dict): + shape = list(lower['shape']) + for i, s in enumerate(shape): + if s in symbols: + shape[i] = symbols[s] + assert isinstance(shape[i], int) + lowers_operands.append(np.zeros(shape, dtype=lower['type'])) + + elif lower == "jit_config": + lowers_operands.append(self.jit_config) + + else: + raise ValueError("Invalid lower operand") # compile compiled_func = jit(func).lower(*lowers_operands).compile() diff --git a/neat/genome/__init__.py b/neat/genome/__init__.py index e69de29..fd413ee 100644 --- a/neat/genome/__init__.py +++ b/neat/genome/__init__.py @@ -0,0 +1,7 @@ +from .mutate import mutate +from .distance import distance +from .crossover import crossover +from .forward import create_forward +from .graph import topological_sort, check_cycles +from .utils import unflatten_connections +from .genome import initialize_genomes, expand, expand_single \ No newline at end of file diff --git a/neat/genome/aggregations.py b/neat/genome/aggregations.py index ed221f1..a9eb8e6 100644 --- a/neat/genome/aggregations.py +++ b/neat/genome/aggregations.py @@ -1,34 +1,27 @@ -import jax import jax.numpy as jnp -import numpy as np -from jax import jit -@jit + def sum_agg(z): z = jnp.where(jnp.isnan(z), 0, z) return jnp.sum(z, axis=0) -@jit def product_agg(z): z = jnp.where(jnp.isnan(z), 1, z) return jnp.prod(z, axis=0) -@jit def max_agg(z): z = jnp.where(jnp.isnan(z), -jnp.inf, z) return jnp.max(z, axis=0) -@jit def min_agg(z): z = jnp.where(jnp.isnan(z), jnp.inf, z) return jnp.min(z, axis=0) -@jit def maxabs_agg(z): z = jnp.where(jnp.isnan(z), 0, z) abs_z = jnp.abs(z) @@ -36,7 +29,6 @@ def maxabs_agg(z): return z[max_abs_index] -@jit def median_agg(z): non_nan_mask = ~jnp.isnan(z) n = jnp.sum(non_nan_mask, axis=0) @@ -49,7 +41,6 @@ def median_agg(z): return median -@jit def mean_agg(z): non_zero_mask = ~jnp.isnan(z) valid_values_sum = sum_agg(z) diff --git a/neat/genome/mutate.py b/neat/genome/mutate.py index 93c6c15..331dd11 100644 --- a/neat/genome/mutate.py +++ b/neat/genome/mutate.py @@ -10,7 +10,7 @@ import jax from jax import numpy as jnp from jax import jit, Array -from .utils import fetch_random, fetch_first, I_INT +from .utils import fetch_random, fetch_first, I_INT, unflatten_connections from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection from .graph import check_cycles @@ -273,7 +273,8 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config is_already_exist = con_idx != I_INT - is_cycle = check_cycles(nodes, cons, from_idx, to_idx) + u_cons = unflatten_connections(nodes, cons) + is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful]) diff --git a/neat/genome/utils.py b/neat/genome/utils.py index 9e1ef2f..93fadca 100644 --- a/neat/genome/utils.py +++ b/neat/genome/utils.py @@ -1,12 +1,11 @@ -from functools import partial - +import numpy as np import jax from jax import numpy as jnp, Array from jax import jit, vmap -I_INT = jnp.iinfo(jnp.int32).max # infinite int -EMPTY_NODE = jnp.full((1, 5), jnp.nan) -EMPTY_CON = jnp.full((1, 4), jnp.nan) +I_INT = np.iinfo(jnp.int32).max # infinite int +EMPTY_NODE = np.full((1, 5), jnp.nan) +EMPTY_CON = np.full((1, 4), jnp.nan) @jit @@ -58,8 +57,3 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array: mask = jnp.where(true_cnt == 0, False, cumsum >= target) return fetch_first(mask, default) -@jit -def argmin_with_mask(arr: Array, mask: Array) -> Array: - masked_arr = jnp.where(mask, arr, jnp.inf) - min_idx = jnp.argmin(masked_arr) - return min_idx \ No newline at end of file diff --git a/neat/operations.py b/neat/operations.py new file mode 100644 index 0000000..7984e2a --- /dev/null +++ b/neat/operations.py @@ -0,0 +1,171 @@ +""" +contains operations on the population: creating the next generation and population speciation. +""" +from functools import partial + +import jax +import jax.numpy as jnp +from jax import jit, vmap + +from jax import Array + +from .genome import distance, mutate, crossover +from .genome.utils import I_INT, fetch_first + + +@jit +def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, + center_nodes, center_cons, species_keys, new_species_key_start, + jit_config): + # create next generation + pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, + new_node_keys, jit_config) + + # speciate + idx2specie, spe_center_nodes, spe_center_cons, species_keys = \ + speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config) + + return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys + + +@jit +def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config): + # prepare random keys + pop_size = pop_nodes.shape[0] + k1, k2 = jax.random.split(rand_key, 2) + crossover_rand_keys = jax.random.split(k1, pop_size) + mutate_rand_keys = jax.random.split(k2, pop_size) + + # batch crossover + wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections + lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections + npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections + + # batch mutation + mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None)) + m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) + pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) + + return pop_nodes, pop_cons + + +@jit +def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config): + """ + args: + pop_nodes: (pop_size, N, 5) + pop_cons: (pop_size, C, 4) + spe_center_nodes: (species_size, N, 5) + spe_center_cons: (species_size, C, 4) + """ + pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0] + + # prepare distance functions + o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population + s2p_distance_func = vmap( + o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population + ) + + # idx to specie key + idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species + + # part 1: find new centers + # the distance between each species' center and each genome in population + s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config) + + def find_new_centers(i, carry): + i2s, cn, cc = carry + # find new center + idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT) + + # check species[i] exist or not + # if not exist, set idx and i to I_INT, jax will not do array value assignment + idx = jnp.where(species_keys[i] != I_INT, idx, I_INT) + i = jnp.where(species_keys[i] != I_INT, i, I_INT) + + i2s = i2s.at[idx].set(species_keys[i]) + cn = cn.at[i].set(pop_nodes[idx]) + cc = cc.at[i].set(pop_cons[idx]) + return i2s, cn, cc + + idx2specie, center_nodes, center_cons = \ + jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons)) + + # part 2: assign members to each species + def cond_func(carry): + i, i2s, cn, cc, sk, ck = carry # sk is short for species_keys, ck is short for current key + not_all_assigned = ~jnp.all(i2s != I_INT) + not_reach_species_upper_bounds = i < species_size + return not_all_assigned & not_reach_species_upper_bounds + + def body_func(carry): + i, i2s, cn, cc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons + + i2s, scn, scc, sk, ck = jax.lax.cond( + sk[i] == I_INT, # whether the current species is existing or not + create_new_specie, # if not existing, create a new specie + update_exist_specie, # if existing, update the specie + (i, i2s, cn, cc, sk, ck) + ) + + return i + 1, i2s, scn, scc, sk, ck + + def create_new_specie(carry): + i, i2s, cn, cc, sk, ck = carry + + # pick the first one who has not been assigned to any species + idx = fetch_first(i2s == I_INT) + + # assign it to the new species + sk = sk.at[i].set(ck) + i2s = i2s.at[idx].set(ck) + + # update center genomes + cn = cn.at[i].set(pop_nodes[idx]) + cc = cc.at[i].set(pop_cons[idx]) + + i2s = speciate_by_threshold((i, i2s, cn, cc, sk)) + return i2s, cn, cc, sk, ck + 1 # change to next new speciate key + + def update_exist_specie(carry): + i, i2s, cn, cc, sk, ck = carry + + i2s = speciate_by_threshold((i, i2s, cn, cc, sk)) + + return i2s, cn, cc, sk, ck + + def speciate_by_threshold(carry): + i, i2s, cn, cc, sk = carry + + # distance between such center genome and ppo genomes + o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) + close_enough_mask = o2p_distance < jit_config['compatibility_threshold'] + + # when it is close enough, assign it to the species, remember not to update genome has already been assigned + i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s) + return i2s + + current_new_key = new_species_key_start + + # update idx2specie + _, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop( + cond_func, + body_func, + (0, idx2specie, center_nodes, center_cons, species_keys, current_new_key) + ) + + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition seems to be only happened when the number of species is reached species upper bounds + idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie) + + return idx2specie, center_nodes, center_cons, species_keys + + +@jit +def argmin_with_mask(arr: Array, mask: Array) -> Array: + masked_arr = jnp.where(mask, arr, jnp.inf) + min_idx = jnp.argmin(masked_arr) + return min_idx diff --git a/neat/pipeline.py b/neat/pipeline.py index a99fd5b..45c92f4 100644 --- a/neat/pipeline.py +++ b/neat/pipeline.py @@ -1,11 +1,13 @@ -from functools import partial +import time +from typing import Union, Callable import numpy as np import jax -from configs.configer import Configer -from .genome.genome import initialize_genomes +from configs import Configer +from .genome import initialize_genomes, expand, expand_single from .function_factory import FunctionFactory +from .species import SpeciesController class Pipeline: @@ -19,7 +21,7 @@ class Pipeline: self.config = config # global config self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions - self.function_factory = function_factory or FunctionFactory(self.config) + self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config) self.symbols = { 'P': self.config['pop_size'], @@ -31,8 +33,16 @@ class Pipeline: self.generation = 0 self.best_genome = None + self.species_controller = SpeciesController(self.config) self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config) + self.species_controller.init_speciate(self.pop_nodes, self.pop_cons) + self.best_fitness = float('-inf') + self.best_genome = None + self.generation_timestamp = time.time() + + self.evaluate_time = 0 + print(self.config) def ask(self): """ @@ -74,5 +84,105 @@ class Pipeline: else: raise NotImplementedError + + def tell(self, fitnesses): + self.generation += 1 + + winner, loser, elite_mask, center_nodes, center_cons, species_keys, species_key_start = \ + self.species_controller.ask(fitnesses, self.generation, self.symbols) + + # node keys to be used in the mutation process + new_node_keys = np.arange(self.generation * self.config['pop_size'], + self.generation * self.config['pop_size'] + self.config['pop_size']) + + # create the next generation and then speciate the population + self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ + self.get_func('create_next_generation_then_speciate') \ + (self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes, + center_cons, species_keys, species_key_start, self.jit_config) + + # carry data to cpu + self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ + jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys]) + + self.species_controller.tell(idx2specie, center_nodes, center_cons, species_keys, self.generation) + + # expand the population if needed + self.expand() + + # update randkey + self.randkey = jax.random.split(self.randkey)[0] + + def expand(self): + """ + Expand the population if needed. + when the maximum node number >= N or the maximum connection number of >= C + the population will expand + """ + changed = False + + pop_node_keys = self.pop_nodes[:, :, 0] + pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1) + max_node_size = np.max(pop_node_sizes) + if max_node_size >= self.symbols['N']: + self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe']) + print(f"node expand to {self.symbols['N']}!") + changed = True + + pop_con_keys = self.pop_cons[:, :, 0] + pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) + max_con_size = np.max(pop_node_sizes) + if max_con_size >= self.symbols['C']: + self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe']) + print(f"connection expand to {self.symbols['C']}!") + changed = True + + if changed: + self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C']) + # don't forget to expand representation genome in species + for s in self.species_controller.species.values(): + s.representative = expand_single(*s.representative, self.symbols['N'], self.symbols['C']) + def get_func(self, name): return self.function_factory.get(name, self.symbols) + + def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): + for _ in range(self.config['generation_limit']): + forward_func = self.ask() + + tic = time.time() + fitnesses = fitness_func(forward_func) + self.evaluate_time += time.time() - tic + + assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" + + if analysis is not None: + if analysis == "default": + self.default_analysis(fitnesses) + else: + assert callable(analysis), f"What the fuck you passed in? A {analysis}?" + analysis(fitnesses) + + if max(fitnesses) >= self.config['fitness_threshold']: + print("Fitness limit reached!") + return self.best_genome + + self.tell(fitnesses) + print("Generation limit reached!") + return self.best_genome + + def default_analysis(self, fitnesses): + max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) + species_sizes = [len(s.members) for s in self.species_controller.species.values()] + + new_timestamp = time.time() + cost_time = new_timestamp - self.generation_timestamp + self.generation_timestamp = new_timestamp + + max_idx = np.argmax(fitnesses) + if fitnesses[max_idx] > self.best_fitness: + self.best_fitness = fitnesses[max_idx] + self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) + + print(f"Generation: {self.generation}", + f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") diff --git a/neat/population.py b/neat/population.py deleted file mode 100644 index 9ad0bb6..0000000 --- a/neat/population.py +++ /dev/null @@ -1,168 +0,0 @@ -from functools import partial - -import jax -import jax.numpy as jnp -from jax import jit, vmap - -from jax import Array - -from .genome import distance, mutate, crossover -from .genome.utils import I_INT, fetch_first, argmin_with_mask - - -@jit -def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, - new_node_keys, - pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start, - species_kwargs, mutate_kwargs): - # create next generation - pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, - new_node_keys, **mutate_kwargs) - - # speciate - idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes, - pre_spe_center_cons, species_keys, - new_species_key_start, **species_kwargs) - - return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys - - -@jit -def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array, - species_keys, new_species_key_start, - disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0 - ): - """ - args: - pop_nodes: (pop_size, N, 5) - pop_cons: (pop_size, C, 4) - spe_center_nodes: (species_size, N, 5) - spe_center_cons: (species_size, C, 4) - """ - pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0] - - # prepare distance functions - distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe) - o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) - s2p_distance_func = vmap( - o2p_distance_func, in_axes=(0, 0, None, None) - ) - - # idx to specie key - idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species - - # part 1: find new centers - # the distance between each species' center and each genome in population - s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons) - - def find_new_centers(i, carry): - i2s, scn, scc = carry - # find new center - idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT) - - # check species[i] exist or not - # if not exist, set idx and i to I_INT, jax will not do array value assignment - idx = jnp.where(species_keys[i] != I_INT, idx, I_INT) - i = jnp.where(species_keys[i] != I_INT, i, I_INT) - - i2s = i2s.at[idx].set(species_keys[i]) - scn = scn.at[i].set(pop_nodes[idx]) - scc = scc.at[i].set(pop_cons[idx]) - return i2s, scn, scc - - idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons)) - - def continue_execute_while(carry): - i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key - not_all_assigned = ~jnp.all(i2s != I_INT) - not_reach_species_upper_bounds = i < species_size - return not_all_assigned & not_reach_species_upper_bounds - - def deal_with_each_center_genome(carry): - i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons - center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i] - - i2s, scn, scc, sk, ck = jax.lax.cond( - jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid - create_new_specie, # if not valid, create a new specie - update_exist_specie, # if valid, update the specie - (i, i2s, scn, scc, sk, ck) - ) - - return i + 1, i2s, scn, scc, sk, ck - - def create_new_specie(carry): - i, i2s, scn, scc, sk, ck = carry - # pick the first one who has not been assigned to any species - idx = fetch_first(i2s == I_INT) - - # assign it to new specie - sk = sk.at[i].set(ck) - i2s = i2s.at[idx].set(ck) - - # update center genomes - scn = scn.at[i].set(pop_nodes[idx]) - scc = scc.at[i].set(pop_cons[idx]) - - i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk)) - return i2s, scn, scc, sk, ck + 1 # change to next new speciate key - - def update_exist_specie(carry): - i, i2s, scn, scc, sk, ck = carry - - i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk)) - return i2s, scn, scc, sk, ck - - def speciate_by_threshold(carry): - i, i2s, scn, scc, sk = carry - # distance between such center genome and ppo genomes - o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons) - close_enough_mask = o2p_distance < compatibility_threshold - - # when it is close enough, assign it to the species, remember not to update genome has already been assigned - i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s) - return i2s, scn, scc - - current_new_key = new_species_key_start - - # update idx2specie - _, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop( - continue_execute_while, - deal_with_each_center_genome, - (0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key) - ) - - # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition seems to be only happened when the number of species is reached species upper bounds - idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie) - - return idx2specie, spe_center_nodes, spe_center_cons, species_keys - - -@jit -def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys, - **mutate_kwargs): - # prepare functions - batch_crossover = vmap(crossover) - mutate_with_args = vmap(partial(mutate, **mutate_kwargs)) - - pop_size = pop_nodes.shape[0] - k1, k2 = jax.random.split(rand_key, 2) - crossover_rand_keys = jax.random.split(k1, pop_size) - mutate_rand_keys = jax.random.split(k2, pop_size) - - # batch crossover - wpn = pop_nodes[winner_part] # winner pop nodes - wpc = pop_cons[winner_part] # winner pop connections - lpn = pop_nodes[loser_part] # loser pop nodes - lpc = pop_cons[loser_part] # loser pop connections - - npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - - m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) - pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) - - return pop_nodes, pop_cons diff --git a/neat/species.py b/neat/species.py index ff3e015..ce7415a 100644 --- a/neat/species.py +++ b/neat/species.py @@ -1,7 +1,15 @@ -from typing import List, Tuple, Dict, Union, Callable +""" +Species Controller in NEAT. +The code are modified from neat-python. +See + https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation + https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction + https://neat-python.readthedocs.io/en/latest/module_summaries.html#species +""" + +from typing import List, Tuple, Dict from itertools import count -import jax import numpy as np from numpy.typing import NDArray @@ -37,14 +45,13 @@ class SpeciesController: def __init__(self, config): self.config = config - self.species_elitism = self.config.neat.species.species_elitism - self.pop_size = self.config.neat.population.pop_size - self.max_stagnation = self.config.neat.species.max_stagnation - self.min_species_size = self.config.neat.species.min_species_size - self.genome_elitism = self.config.neat.species.genome_elitism - self.survival_threshold = self.config.neat.species.survival_threshold + self.species_elitism = self.config['species_elitism'] + self.pop_size = self.config['pop_size'] + self.max_stagnation = self.config['max_stagnation'] + self.min_species_size = self.config['min_species_size'] + self.genome_elitism = self.config['genome_elitism'] + self.survival_threshold = self.config['survival_threshold'] - self.species_idxer = count(0) self.species: Dict[int, Species] = {} # species_id -> species def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray): @@ -55,9 +62,10 @@ class SpeciesController: :return: """ pop_size = pop_nodes.shape[0] - species_id = next(self.species_idxer) + species_id = 0 # the first species s = Species(species_id, 0) members = np.array(list(range(pop_size))) + s.update((pop_nodes[0], pop_connections[0]), members) self.species[species_id] = s @@ -68,16 +76,14 @@ class SpeciesController: :return: """ for sid, s in self.species.items(): - # TODO: here use mean to measure the fitness of a species, but it may be other functions s.member_fitnesses = s.get_fitnesses(fitnesses) - # s.fitness = np.mean(s.member_fitnesses) + # use the max score to represent the fitness of the species s.fitness = np.max(s.member_fitnesses) s.fitness_history.append(s.fitness) s.adjusted_fitness = None def __stagnation(self, generation): """ - code modified from neat-python! :param generation: :return: whether the species is stagnated """ @@ -88,7 +94,7 @@ class SpeciesController: else: prev_fitness = float('-inf') - if prev_fitness is None or s.fitness > prev_fitness: + if s.fitness > prev_fitness: s.last_improved = generation species_data.append((sid, s)) @@ -110,7 +116,6 @@ class SpeciesController: def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]: """ - code modified from neat-python! :param fitnesses: :param generation: :return: crossover_pair for next generation. @@ -136,6 +141,8 @@ class SpeciesController: # No species left. assert remaining_species + + # TODO: Too complex! # Compute each species' member size in the next generation. # Do not allow the fitness range to be zero, as we divide by it below. @@ -185,6 +192,7 @@ class SpeciesController: # only use good genomes to crossover sorted_members = sorted_members[:repro_cutoff] + # TODO: Genome with higher fitness should be more likely to be selected? list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True) part1.extend(sorted_members[list_idx1]) part2.extend(sorted_members[list_idx2]) @@ -197,32 +205,37 @@ class SpeciesController: return winner_part, loser_part, np.array(elite_mask) - def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation): + def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation): for idx, key in enumerate(species_keys): if key == I_INT: continue members = np.where(idx2specie == key)[0] assert len(members) > 0 + if key not in self.species: + # the new specie created in this generation s = Species(key, generation) self.species[key] = s - self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members) + self.species[key].update((center_nodes[idx], center_cons[idx]), members) - def ask(self, fitnesses, generation, S, N, C): + def ask(self, fitnesses, generation, symbols): self.__update_species_fitnesses(fitnesses) - winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation) - pre_spe_center_nodes = np.full((S, N, 5), np.nan) - pre_spe_center_cons = np.full((S, C, 4), np.nan) - species_keys = np.full((S,), I_INT) + + winner, loser, elite_mask = self.__reproduce(fitnesses, generation) + + center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan) + center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan) + species_keys = np.full((symbols['S'], ), I_INT) + for idx, (key, specie) in enumerate(self.species.items()): - pre_spe_center_nodes[idx] = specie.representative[0] - pre_spe_center_cons[idx] = specie.representative[1] + center_nodes[idx], center_cons[idx] = specie.representative species_keys[idx] = key + next_new_specie_key = max(self.species.keys()) + 1 - return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \ - pre_spe_center_cons, species_keys, next_new_specie_key + + return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):