diff --git a/algorithms/neat/__init__.py b/algorithms/neat/__init__.py index f1c4384..d94a11a 100644 --- a/algorithms/neat/__init__.py +++ b/algorithms/neat/__init__.py @@ -1,6 +1,5 @@ """ contains operations on a single genome. e.g. forward, mutate, crossover, etc. """ -from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single -from .operations import create_next_generation_then_speciate -from .species import SpeciesController +from .genome import create_forward_function, topological_sort, unflatten_connections +from .population import update_species, create_next_generation, speciate diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py index fd413ee..97f89a0 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -1,7 +1,6 @@ from .mutate import mutate from .distance import distance from .crossover import crossover -from .forward import create_forward from .graph import topological_sort, check_cycles from .utils import unflatten_connections -from .genome import initialize_genomes, expand, expand_single \ No newline at end of file +from .forward import create_forward_function diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index 40f7106..9a4b161 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -5,7 +5,7 @@ from jax import jit, vmap from .utils import I_INT -def create_forward(config): +def create_forward_function(config): """ meta method to create forward function """ @@ -83,4 +83,22 @@ def create_forward(config): return vals[output_idx] + # (batch_size, inputs_nums) -> (batch_size, outputs_nums) + batch_forward = vmap(forward, in_axes=(0, None, None, None)) + + # (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) + pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0)) + + # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) + common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0)) + + if config['forward_way'] == 'single': + return jit(batch_forward) + + elif config['forward_way'] == 'pop': + return jit(pop_batch_forward) + + elif config['forward_way'] == 'common': + return jit(common_forward) + return forward diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 4f2d32b..2679289 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -65,55 +65,6 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]: return pop_nodes, pop_cons -def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]: - """ - Expand a single genome to accommodate more nodes or connections. - :param nodes: (N, 5) - :param cons: (C, 4) - :param new_N: - :param new_C: - :return: (new_N, 5), (new_C, 4) - """ - old_N, old_C = nodes.shape[0], cons.shape[0] - new_nodes = np.full((new_N, 5), np.nan) - new_nodes[:old_N, :] = nodes - - new_cons = np.full((new_C, 4), np.nan) - new_cons[:old_C, :] = cons - - return new_nodes, new_cons - - -def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]: - """ - Expand the population to accommodate more nodes or connections. - :param pop_nodes: (pop_size, N, 5) - :param pop_cons: (pop_size, C, 4) - :param new_N: - :param new_C: - :return: (pop_size, new_N, 5), (pop_size, new_C, 4) - """ - pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1] - - new_pop_nodes = np.full((pop_size, new_N, 5), np.nan) - new_pop_nodes[:, :old_N, :] = pop_nodes - - new_pop_cons = np.full((pop_size, new_C, 4), np.nan) - new_pop_cons[:, :old_C, :] = pop_cons - - return new_pop_nodes, new_pop_cons - - -@jit -def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]: - """ - Count how many nodes and connections are in the genome. - """ - node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0])) - cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0])) - return node_cnt, cons_cnt - - @jit def add_node(nodes: NDArray, cons: NDArray, new_key: int, bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]: diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py index 4a984c3..9ae2e95 100644 --- a/algorithms/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -59,12 +59,13 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array: target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) mask = jnp.where(true_cnt == 0, False, cumsum >= target) return fetch_first(mask, default) + @partial(jit, static_argnames=['reverse']) def rank_elements(array, reverse=False): """ rank the element in the array. - if reverse is True, the rank is from large to small. + if reverse is True, the rank is from small to large. default large to small """ - if reverse: + if not reverse: array = -array return jnp.argsort(jnp.argsort(array)) \ No newline at end of file diff --git a/algorithms/neat/jit_species.py b/algorithms/neat/jit_species.py deleted file mode 100644 index a4f4f44..0000000 --- a/algorithms/neat/jit_species.py +++ /dev/null @@ -1,160 +0,0 @@ -from functools import partial - -import jax -from jax import jit, numpy as jnp, vmap - -from .genome.utils import rank_elements - - -@jit -def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config): - """ - args: - randkey: random key - fitness: Array[(pop_size,), float], the fitness of each individual - species_keys: Array[(species_size, 3), float], the information of each species - [species_key, best_score, last_update] - idx2species: Array[(pop_size,), int], map the individual to its species - center_nodes: Array[(species_size, N, 4), float], the center nodes of each species - center_cons: Array[(species_size, C, 4), float], the center connections of each species - generation: int, current generation - jit_config: Dict, the configuration of jit functions - """ - - # update the fitness of each species - species_fitness = update_species_fitness(species_info, idx2species, fitness) - - # stagnation species - species_fitness, species_info, center_nodes, center_cons = \ - stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config) - - # sort species_info by their fitness. (push nan to the end) - sort_indices = jnp.argsort(species_fitness)[::-1] - species_info = species_info[sort_indices] - center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices] - - # decide the number of members of each species by their fitness - spawn_number = cal_spawn_numbers(species_info, jit_config) - - # crossover info - winner, loser, elite_mask = \ - create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config) - - jax.debug.print("{}, {}", fitness, winner) - jax.debug.print("{}", fitness[winner]) - - return species_info, center_nodes, center_cons, winner, loser, elite_mask - - -def update_species_fitness(species_info, idx2species, fitness): - """ - obtain the fitness of the species by the fitness of each individual. - use max criterion. - """ - - def aux_func(idx): - species_key = species_info[idx, 0] - s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf) - f = jnp.max(s_fitness) - return f - - return vmap(aux_func)(jnp.arange(species_info.shape[0])) - - -def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config): - """ - stagnation species. - those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. - elitism species never stagnation - """ - - def aux_func(idx): - s_fitness = species_fitness[idx] - species_key, best_score, last_update = species_info[idx] - # stagnation condition - return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation']) - - st = vmap(aux_func)(jnp.arange(species_info.shape[0])) - - # elite species will not be stagnation - species_rank = rank_elements(species_fitness) - st = jnp.where(species_rank < jit_config['species_elitism'], False, st) # elitism never stagnation - - # set stagnation species to nan - species_info = jnp.where(st[:, None], jnp.nan, species_info) - center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes) - center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons) - species_fitness = jnp.where(st, jnp.nan, species_fitness) - - return species_fitness, species_info, center_nodes, center_cons - - -def cal_spawn_numbers(species_info, jit_config): - """ - decide the number of members of each species by their fitness rank. - the species with higher fitness will have more members - Linear ranking selection - e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] - """ - - is_species_valid = ~jnp.isnan(species_info[:, 0]) - valid_species_num = jnp.sum(is_species_valid) - denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 - - rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1] - spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] - spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 - - spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member - - # must control the sum of spawn_number to be equal to pop_size - error = jit_config['pop_size'] - jnp.sum(spawn_number) - spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number - - return spawn_number - - -def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config): - - species_size = species_info.shape[0] - pop_size = fitness.shape[0] - s_idx = jnp.arange(species_size) - p_idx = jnp.arange(pop_size) - - def aux_func(key, idx): - members = idx2species == species_info[idx, 0] - members_num = jnp.sum(members) - - members_fitness = jnp.where(members, fitness, jnp.nan) - sorted_member_indices = jnp.argsort(members_fitness)[::-1] - - elite_size = jit_config['genome_elitism'] - survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32) - - select_pro = (p_idx < survive_size) / survive_size - fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) - - # elite - fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) - ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) - elite = jnp.where(p_idx < elite_size, True, False) - return fa, ma, elite - - fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) - - spawn_number_cum = jnp.cumsum(spawn_number) - - def aux_func(idx): - loc = jnp.argmax(idx < spawn_number_cum) - - # elite genomes are at the beginning of the species - idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) - return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] - - part1, part2, elite_mask = vmap(aux_func)(p_idx) - - is_part1_win = fitness[part1] >= fitness[part2] - winner = jnp.where(is_part1_win, part1, part2) - loser = jnp.where(is_part1_win, part2, part1) - - return winner, loser, elite_mask diff --git a/algorithms/neat/operations.py b/algorithms/neat/operations.py deleted file mode 100644 index bb0a758..0000000 --- a/algorithms/neat/operations.py +++ /dev/null @@ -1,166 +0,0 @@ -""" -contains operations on the population: creating the next generation and population speciation. -""" -import jax -from jax import jit, vmap, Array, numpy as jnp - -from .genome import distance, mutate, crossover -from .genome.utils import I_INT, fetch_first - - -@jit -def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, - center_nodes, center_cons, species_keys, new_species_key_start, - jit_config): - # create next generation - pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, - new_node_keys, jit_config) - - # speciate - idx2specie, spe_center_nodes, spe_center_cons, species_keys = \ - speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config) - - return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys - - -@jit -def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config): - # prepare random keys - pop_size = pop_nodes.shape[0] - k1, k2 = jax.random.split(rand_key, 2) - crossover_rand_keys = jax.random.split(k1, pop_size) - mutate_rand_keys = jax.random.split(k2, pop_size) - - # batch crossover - wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections - lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections - npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - - # batch mutation - mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None)) - m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) - pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) - - return pop_nodes, pop_cons - - -@jit -def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config): - """ - args: - pop_nodes: (pop_size, N, 5) - pop_cons: (pop_size, C, 4) - spe_center_nodes: (species_size, N, 5) - spe_center_cons: (species_size, C, 4) - """ - pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0] - - # prepare distance functions - o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population - s2p_distance_func = vmap( - o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population - ) - - # idx to specie key - idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species - - # part 1: find new centers - # the distance between each species' center and each genome in population - s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config) - - def find_new_centers(i, carry): - i2s, cn, cc = carry - # find new center - idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT) - - # check species[i] exist or not - # if not exist, set idx and i to I_INT, jax will not do array value assignment - idx = jnp.where(species_keys[i] != I_INT, idx, I_INT) - i = jnp.where(species_keys[i] != I_INT, i, I_INT) - - i2s = i2s.at[idx].set(species_keys[i]) - cn = cn.at[i].set(pop_nodes[idx]) - cc = cc.at[i].set(pop_cons[idx]) - return i2s, cn, cc - - idx2specie, center_nodes, center_cons = \ - jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons)) - - # part 2: assign members to each species - def cond_func(carry): - i, i2s, cn, cc, sk, ck = carry # sk is short for species_keys, ck is short for current key - not_all_assigned = ~jnp.all(i2s != I_INT) - not_reach_species_upper_bounds = i < species_size - return not_all_assigned & not_reach_species_upper_bounds - - def body_func(carry): - i, i2s, cn, cc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons - - i2s, scn, scc, sk, ck = jax.lax.cond( - sk[i] == I_INT, # whether the current species is existing or not - create_new_specie, # if not existing, create a new specie - update_exist_specie, # if existing, update the specie - (i, i2s, cn, cc, sk, ck) - ) - - return i + 1, i2s, scn, scc, sk, ck - - def create_new_specie(carry): - i, i2s, cn, cc, sk, ck = carry - - # pick the first one who has not been assigned to any species - idx = fetch_first(i2s == I_INT) - - # assign it to the new species - sk = sk.at[i].set(ck) - i2s = i2s.at[idx].set(ck) - - # update center genomes - cn = cn.at[i].set(pop_nodes[idx]) - cc = cc.at[i].set(pop_cons[idx]) - - i2s = speciate_by_threshold((i, i2s, cn, cc, sk)) - return i2s, cn, cc, sk, ck + 1 # change to next new speciate key - - def update_exist_specie(carry): - i, i2s, cn, cc, sk, ck = carry - - i2s = speciate_by_threshold((i, i2s, cn, cc, sk)) - - return i2s, cn, cc, sk, ck - - def speciate_by_threshold(carry): - i, i2s, cn, cc, sk = carry - - # distance between such center genome and ppo genomes - o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) - close_enough_mask = o2p_distance < jit_config['compatibility_threshold'] - - # when it is close enough, assign it to the species, remember not to update genome has already been assigned - i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s) - return i2s - - current_new_key = new_species_key_start - - # update idx2specie - _, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop( - cond_func, - body_func, - (0, idx2specie, center_nodes, center_cons, species_keys, current_new_key) - ) - - # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition seems to be only happened when the number of species is reached species upper bounds - idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie) - - return idx2specie, center_nodes, center_cons, species_keys - - -@jit -def argmin_with_mask(arr: Array, mask: Array) -> Array: - masked_arr = jnp.where(mask, arr, jnp.inf) - min_idx = jnp.argmin(masked_arr) - return min_idx diff --git a/jit_pipeline.py b/algorithms/neat/pipeline.py similarity index 52% rename from jit_pipeline.py rename to algorithms/neat/pipeline.py index 8bb03b0..ac2ea3f 100644 --- a/jit_pipeline.py +++ b/algorithms/neat/pipeline.py @@ -3,13 +3,13 @@ from typing import Union, Callable import numpy as np import jax +from jax import jit, vmap from configs import Configer -from function_factory import FunctionFactory -from algorithms.neat import initialize_genomes, expand, expand_single +from algorithms.neat import initialize_genomes -from algorithms.neat.jit_species import update_species -from algorithms.neat.operations import create_next_generation_then_speciate +from algorithms.neat.population import create_next_generation, speciate, update_species +from algorithms.neat import unflatten_connections, topological_sort, create_forward_function class Pipeline: @@ -17,30 +17,27 @@ class Pipeline: Neat algorithm pipeline. """ - def __init__(self, config, function_factory=None, seed=42): + def __init__(self, config, seed=42): self.randkey = jax.random.PRNGKey(seed) np.random.seed(seed) self.config = config # global config self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions - self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config) - self.symbols = { - 'P': self.config['pop_size'], - 'N': self.config['init_maximum_nodes'], - 'C': self.config['init_maximum_connections'], - 'S': self.config['init_maximum_species'], - } + self.P = config['pop_size'] + self.N = config['init_maximum_nodes'] + self.C = config['init_maximum_connections'] + self.S = config['init_maximum_species'] self.generation = 0 self.best_genome = None - self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config) - self.species_info = np.full((self.symbols['S'], 3), np.nan) + self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config) + self.species_info = np.full((self.S, 3), np.nan) self.species_info[0, :] = 0, -np.inf, 0 - self.idx2species = np.zeros(self.symbols['P'], dtype=np.int32) - self.center_nodes = np.full((self.symbols['S'], self.symbols['N'], 5), np.nan) - self.center_cons = np.full((self.symbols['S'], self.symbols['C'], 4), np.nan) + self.idx2species = np.zeros(self.P, dtype=np.float32) + self.center_nodes = np.full((self.S, self.N, 5), np.nan) + self.center_cons = np.full((self.S, self.C, 4), np.nan) self.center_nodes[0, :, :] = self.pop_nodes[0, :, :] self.center_cons[0, :, :] = self.pop_cons[0, :, :] @@ -49,7 +46,10 @@ class Pipeline: self.generation_timestamp = time.time() self.evaluate_time = 0 - print(self.config) + + self.pop_unflatten_connections = jit(vmap(unflatten_connections)) + self.pop_topological_sort = jit(vmap(topological_sort)) + self.forward = create_forward_function(config) def ask(self): """ @@ -71,52 +71,28 @@ class Pipeline: e.g. numerical regression; Hyper-NEAT """ - u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons) - pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons) + u_pop_cons = self.pop_unflatten_connections(self.pop_nodes, self.pop_cons) + pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons) - if self.config['forward_way'] == 'single': - forward_funcs = [] - for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons): - func = lambda x: self.get_func('forward')(x, seq, nodes, cons) - forward_funcs.append(func) - return forward_funcs - - elif self.config['forward_way'] == 'pop': - func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) - return func - - elif self.config['forward_way'] == 'common': - func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) - return func - - else: - raise NotImplementedError + # only common mode is supported currently + assert self.config['forward_way'] == 'common' + return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons) def tell(self, fitnesses): self.generation += 1 - species_info, center_nodes, center_cons, winner, loser, elite_mask = \ - update_species(self.randkey, fitnesses, self.species_info, self.idx2species, self.center_nodes, + k1, k2, self.randkey = jax.random.split(self.randkey, 3) + + self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \ + update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes, self.center_cons, self.generation, self.jit_config) - # node keys to be used in the mutation process - new_node_keys = np.arange(self.generation * self.config['pop_size'], - self.generation * self.config['pop_size'] + self.config['pop_size']) + self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser, + elite_mask, self.generation, self.jit_config) - # create the next generation and then speciate the population - self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ - create_next_generation_then_speciate(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes, - center_cons, species_keys, species_key_start, self.jit_config) - - # carry data to cpu - self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ - jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys]) - - # update randkey - self.randkey = jax.random.split(self.randkey)[0] - - def get_func(self, name): - return self.function_factory.get(name, self.symbols) + self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate( + self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation, + self.jit_config) def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config['generation_limit']): diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py new file mode 100644 index 0000000..f8a1f96 --- /dev/null +++ b/algorithms/neat/population.py @@ -0,0 +1,307 @@ +""" +contains operations on the population: creating the next generation and population speciation. +""" +import jax +from jax import jit, vmap, Array, numpy as jnp + +from .genome import distance, mutate, crossover +from .genome.utils import I_INT, fetch_first, rank_elements + + +@jit +def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config): + """ + args: + randkey: random key + fitness: Array[(pop_size,), float], the fitness of each individual + species_keys: Array[(species_size, 3), float], the information of each species + [species_key, best_score, last_update] + idx2species: Array[(pop_size,), int], map the individual to its species + center_nodes: Array[(species_size, N, 4), float], the center nodes of each species + center_cons: Array[(species_size, C, 4), float], the center connections of each species + generation: int, current generation + jit_config: Dict, the configuration of jit functions + """ + + # update the fitness of each species + species_fitness = update_species_fitness(species_info, idx2species, fitness) + + # stagnation species + species_fitness, species_info, center_nodes, center_cons = \ + stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config) + + # sort species_info by their fitness. (push nan to the end) + sort_indices = jnp.argsort(species_fitness)[::-1] + species_info = species_info[sort_indices] + center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices] + + # decide the number of members of each species by their fitness + spawn_number = cal_spawn_numbers(species_info, jit_config) + + # crossover info + winner, loser, elite_mask = \ + create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config) + + return species_info, center_nodes, center_cons, winner, loser, elite_mask + + +def update_species_fitness(species_info, idx2species, fitness): + """ + obtain the fitness of the species by the fitness of each individual. + use max criterion. + """ + + def aux_func(idx): + species_key = species_info[idx, 0] + s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf) + f = jnp.max(s_fitness) + return f + + return vmap(aux_func)(jnp.arange(species_info.shape[0])) + + +def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config): + """ + stagnation species. + those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. + elitism species never stagnation + """ + + def aux_func(idx): + s_fitness = species_fitness[idx] + species_key, best_score, last_update = species_info[idx] + st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation']) + last_update = jnp.where(s_fitness > best_score, generation, last_update) + best_score = jnp.where(s_fitness > best_score, s_fitness, best_score) + # stagnation condition + return st, jnp.array([species_key, best_score, last_update]) + + spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0])) + + # elite species will not be stagnation + species_rank = rank_elements(species_fitness) + spe_st = jnp.where(species_rank < jit_config['species_elitism'], False, spe_st) # elitism never stagnation + + # set stagnation species to nan + species_info = jnp.where(spe_st[:, None], jnp.nan, species_info) + center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, center_nodes) + center_cons = jnp.where(spe_st[:, None, None], jnp.nan, center_cons) + species_fitness = jnp.where(spe_st, jnp.nan, species_fitness) + + return species_fitness, species_info, center_nodes, center_cons + + +def cal_spawn_numbers(species_info, jit_config): + """ + decide the number of members of each species by their fitness rank. + the species with higher fitness will have more members + Linear ranking selection + e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] + """ + + is_species_valid = ~jnp.isnan(species_info[:, 0]) + valid_species_num = jnp.sum(is_species_valid) + denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 + + rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1] + spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] + spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 + + spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member + + # must control the sum of spawn_number to be equal to pop_size + error = jit_config['pop_size'] - jnp.sum(spawn_number) + spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number + + return spawn_number + + +def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config): + + species_size = species_info.shape[0] + pop_size = fitness.shape[0] + s_idx = jnp.arange(species_size) + p_idx = jnp.arange(pop_size) + + # def aux_func(key, idx): + def aux_func(key, idx): + members = idx2species == species_info[idx, 0] + members_num = jnp.sum(members) + + members_fitness = jnp.where(members, fitness, -jnp.inf) + sorted_member_indices = jnp.argsort(members_fitness)[::-1] + + elite_size = jit_config['genome_elitism'] + survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32) + + select_pro = (p_idx < survive_size) / survive_size + fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) + + # elite + fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) + ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) + elite = jnp.where(p_idx < elite_size, True, False) + return fa, ma, elite + + # fas, mas, elites = jax.lax.max(aux_func, (jax.random.split(randkey, species_size), s_idx)) + fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) + + spawn_number_cum = jnp.cumsum(spawn_number) + + def aux_func(idx): + loc = jnp.argmax(idx < spawn_number_cum) + + # elite genomes are at the beginning of the species + idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) + return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] + + part1, part2, elite_mask = vmap(aux_func)(p_idx) + + is_part1_win = fitness[part1] >= fitness[part2] + winner = jnp.where(is_part1_win, part1, part2) + loser = jnp.where(is_part1_win, part2, part1) + + return winner, loser, elite_mask + + +@jit +def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config): + # prepare random keys + pop_size = pop_nodes.shape[0] + new_node_keys = jnp.arange(pop_size) + generation * pop_size + + k1, k2 = jax.random.split(rand_key, 2) + crossover_rand_keys = jax.random.split(k1, pop_size) + mutate_rand_keys = jax.random.split(k2, pop_size) + + # batch crossover + wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections + lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections + npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections + + # batch mutation + mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None)) + m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) + pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) + + return pop_nodes, pop_cons + + +@jit +def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config): + """ + args: + pop_nodes: (pop_size, N, 5) + pop_cons: (pop_size, C, 4) + spe_center_nodes: (species_size, N, 5) + spe_center_cons: (species_size, C, 4) + """ + pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0] + + # prepare distance functions + o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population + s2p_distance_func = vmap( + o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population + ) + + # idx to specie key + idx2specie = jnp.full((pop_size,), jnp.nan) # I_INT means not assigned to any species + + # part 1: find new centers + # the distance between each species' center and each genome in population + s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config) + + def find_new_centers(i, carry): + i2s, cn, cc = carry + # find new center + idx = argmin_with_mask(s2p_distance[i], mask=jnp.isnan(i2s)) + + # check species[i] exist or not + # if not exist, set idx and i to I_INT, jax will not do array value assignment + idx = jnp.where(~jnp.isnan(species_info[i, 0]), idx, I_INT) + i = jnp.where(~jnp.isnan(species_info[i, 0]), i, I_INT) + + i2s = i2s.at[idx].set(species_info[i, 0]) + cn = cn.at[i].set(pop_nodes[idx]) + cc = cc.at[i].set(pop_cons[idx]) + return i2s, cn, cc + + idx2specie, center_nodes, center_cons = \ + jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons)) + + # part 2: assign members to each species + def cond_func(carry): + i, i2s, cn, cc, si, ck = carry # si is short for species_info, ck is short for current key + not_all_assigned = jnp.any(jnp.isnan(i2s)) + not_reach_species_upper_bounds = i < species_size + return not_all_assigned & not_reach_species_upper_bounds + + def body_func(carry): + i, i2s, cn, cc, si, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons + + i2s, scn, scc, si, ck = jax.lax.cond( + jnp.isnan(si[i, 0]), # whether the current species is existing or not + create_new_specie, # if not existing, create a new specie + update_exist_specie, # if existing, update the specie + (i, i2s, cn, cc, si, ck) + ) + + return i + 1, i2s, scn, scc, si, ck + + def create_new_specie(carry): + i, i2s, cn, cc, si, ck = carry + + # pick the first one who has not been assigned to any species + idx = fetch_first(jnp.isnan(i2s)) + + # assign it to the new species + si = si.at[i].set(jnp.array([ck, -jnp.inf, generation])) # [key, best score, last update generation] + i2s = i2s.at[idx].set(ck) + + # update center genomes + cn = cn.at[i].set(pop_nodes[idx]) + cc = cc.at[i].set(pop_cons[idx]) + + i2s = speciate_by_threshold((i, i2s, cn, cc, si)) + return i2s, cn, cc, si, ck + 1 # change to next new speciate key + + def update_exist_specie(carry): + i, i2s, cn, cc, si, ck = carry + i2s = speciate_by_threshold((i, i2s, cn, cc, si)) + return i2s, cn, cc, si, ck + + def speciate_by_threshold(carry): + i, i2s, cn, cc, si = carry + + # distance between such center genome and ppo genomes + o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) + close_enough_mask = o2p_distance < jit_config['compatibility_threshold'] + + # when it is close enough, assign it to the species, remember not to update genome has already been assigned + i2s = jnp.where(close_enough_mask & jnp.isnan(i2s), si[i, 0], i2s) + return i2s + + species_keys = species_info[:, 0] + current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1 + + # update idx2specie + _, idx2specie, center_nodes, center_cons, species_info, _ = jax.lax.while_loop( + cond_func, + body_func, + (0, idx2specie, center_nodes, center_cons, species_info, current_new_key) + ) + + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition seems to be only happened when the number of species is reached species upper bounds + idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie) + return idx2specie, center_nodes, center_cons, species_info + + +@jit +def argmin_with_mask(arr: Array, mask: Array) -> Array: + masked_arr = jnp.where(mask, arr, jnp.inf) + min_idx = jnp.argmin(masked_arr) + return min_idx diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py deleted file mode 100644 index 32952a4..0000000 --- a/algorithms/neat/species.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -Species Controller in NEAT. -The code are modified from neat-python. -See - https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation - https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction - https://neat-python.readthedocs.io/en/latest/module_summaries.html#species -""" - -from typing import List, Tuple, Dict - -import numpy as np -from numpy.typing import NDArray - -from .genome.utils import I_INT - - -class Species(object): - - def __init__(self, key, generation): - self.key = key - self.created = generation - self.last_improved = generation - self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections) - self.members: NDArray = None # idx in pop_nodes, pop_connections, - self.fitness = None - self.member_fitnesses = None - self.adjusted_fitness = None - self.fitness_history: List[float] = [] - - def update(self, representative, members): - self.representative = representative - self.members = members - - def get_fitnesses(self, fitnesses): - return fitnesses[self.members] - - -class SpeciesController: - """ - A class to control the species - """ - - def __init__(self, config): - self.config = config - - self.species_elitism = self.config['species_elitism'] - self.pop_size = self.config['pop_size'] - self.max_stagnation = self.config['max_stagnation'] - self.min_species_size = self.config['min_species_size'] - self.genome_elitism = self.config['genome_elitism'] - self.survival_threshold = self.config['survival_threshold'] - - self.species: Dict[int, Species] = {} # species_id -> species - - def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray): - """ - speciate for the first generation - :param pop_connections: - :param pop_nodes: - :return: - """ - pop_size = pop_nodes.shape[0] - species_id = 0 # the first species - s = Species(species_id, 0) - members = np.array(list(range(pop_size))) - - s.update((pop_nodes[0], pop_connections[0]), members) - self.species[species_id] = s - - def __update_species_fitnesses(self, fitnesses): - """ - update the fitness of each species - :param fitnesses: - :return: - """ - for sid, s in self.species.items(): - s.member_fitnesses = s.get_fitnesses(fitnesses) - # use the max score to represent the fitness of the species - s.fitness = np.max(s.member_fitnesses) - s.fitness_history.append(s.fitness) - s.adjusted_fitness = None - - def __stagnation(self, generation): - """ - :param generation: - :return: whether the species is stagnated - """ - species_data = [] - for sid, s in self.species.items(): - if s.fitness_history: - prev_fitness = max(s.fitness_history) - else: - prev_fitness = float('-inf') - - if s.fitness > prev_fitness: - s.last_improved = generation - - species_data.append((sid, s)) - - # Sort in descending fitness order. - species_data.sort(key=lambda x: x[1].fitness, reverse=True) - - result = [] - for idx, (sid, s) in enumerate(species_data): - - if idx < self.species_elitism: # elitism species never stagnate! - is_stagnant = False - else: - stagnant_time = generation - s.last_improved - is_stagnant = stagnant_time > self.max_stagnation - - result.append((sid, s, is_stagnant)) - return result - - def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]: - """ - :param fitnesses: - :param generation: - :return: crossover_pair for next generation. - # int -> idx in the pop_nodes, pop_connections of elitism - # (int, int) -> the father and mother idx to be crossover - """ - # Filter out stagnated species, collect the set of non-stagnated - # species members, and compute their average adjusted fitness. - # The average adjusted fitness scheme (normalized to the interval - # [0, 1]) allows the use of negative fitness values without - # interfering with the shared fitness scheme. - - min_fitness = np.inf - max_fitness = -np.inf - - remaining_species = [] - for stag_sid, stag_s, stagnant in self.__stagnation(generation): - if not stagnant: - min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses)) - max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses)) - remaining_species.append(stag_s) - - # No species left. - assert remaining_species - - - # TODO: Too complex! - # Compute each species' member size in the next generation. - - # Do not allow the fitness range to be zero, as we divide by it below. - # TODO: The ``1.0`` below is rather arbitrary, and should be configurable. - fitness_range = max(1.0, max_fitness - min_fitness) - for afs in remaining_species: - # Compute adjusted fitness. - msf = afs.fitness - af = (msf - min_fitness) / fitness_range # make adjusted fitness in [0, 1] - afs.adjusted_fitness = af - adjusted_fitnesses = [s.adjusted_fitness for s in remaining_species] - previous_sizes = [len(s.members) for s in remaining_species] - min_species_size = max(self.min_species_size, self.genome_elitism) - spawn_amounts = compute_spawn(adjusted_fitnesses, previous_sizes, self.pop_size, min_species_size) - assert sum(spawn_amounts) == self.pop_size - - # generate new population and speciate - self.species = {} - # int -> idx in the pop_nodes, pop_connections of elitism - # (int, int) -> the father and mother idx to be crossover - - part1, part2, elite_mask = [], [], [] - - for spawn, s in zip(spawn_amounts, remaining_species): - assert spawn >= self.genome_elitism - - # retain remain species to next generation - old_members, member_fitnesses = s.members, s.member_fitnesses - s.members = [] - self.species[s.key] = s - - # add elitism genomes to next generation - sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, member_fitnesses) - if self.genome_elitism > 0: - for m in sorted_members[:self.genome_elitism]: - part1.append(m) - part2.append(m) - elite_mask.append(True) - spawn -= 1 - - if spawn <= 0: - continue - - # add genome to be crossover to next generation - repro_cutoff = int(np.ceil(self.survival_threshold * len(sorted_members))) - repro_cutoff = max(repro_cutoff, 2) - # only use good genomes to crossover - sorted_members = sorted_members[:repro_cutoff] - - # TODO: Genome with higher fitness should be more likely to be selected? - list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True) - part1.extend(sorted_members[list_idx1]) - part2.extend(sorted_members[list_idx2]) - elite_mask.extend([False] * spawn) - - part1_fitness, part2_fitness = fitnesses[part1], fitnesses[part2] - is_part1_win = part1_fitness >= part2_fitness - winner_part = np.where(is_part1_win, part1, part2) - loser_part = np.where(is_part1_win, part2, part1) - - return winner_part, loser_part, np.array(elite_mask) - - def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation): - for idx, key in enumerate(species_keys): - if key == I_INT: - continue - - members = np.where(idx2specie == key)[0] - assert len(members) > 0 - - if key not in self.species: - # the new specie created in this generation - s = Species(key, generation) - self.species[key] = s - - self.species[key].update((center_nodes[idx], center_cons[idx]), members) - - def ask(self, fitnesses, generation, symbols): - self.__update_species_fitnesses(fitnesses) - - winner, loser, elite_mask = self.__reproduce(fitnesses, generation) - - center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan) - center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan) - species_keys = np.full((symbols['S'], ), I_INT) - - for idx, (key, specie) in enumerate(self.species.items()): - center_nodes[idx], center_cons[idx] = specie.representative - species_keys[idx] = key - - next_new_specie_key = max(self.species.keys()) + 1 - - return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key - - -def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): - """ - Code from neat-python, the only modification is to fix the population size for each generation. - Compute the proper number of offspring per species (proportional to fitness). - """ - af_sum = sum(adjusted_fitness) - - spawn_amounts = [] - for af, ps in zip(adjusted_fitness, previous_sizes): - if af_sum > 0: - s = max(min_species_size, af / af_sum * pop_size) - else: - s = min_species_size - - d = (s - ps) * 0.5 - c = int(round(d)) - spawn = ps - if abs(c) > 0: - spawn += c - elif d > 0: - spawn += 1 - elif d < 0: - spawn -= 1 - - spawn_amounts.append(spawn) - - # Normalize the spawn amounts so that the next generation is roughly - # the population size requested by the user. - total_spawn = sum(spawn_amounts) - norm = pop_size / total_spawn - spawn_amounts = [max(min_species_size, int(round(n * norm))) for n in spawn_amounts] - - # for batch parallelization, pop size must be a fixed value. - total_amounts = sum(spawn_amounts) - spawn_amounts[0] += pop_size - total_amounts - assert sum(spawn_amounts) == pop_size, "Population size is not stable." - - return spawn_amounts - - -def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \ - -> Tuple[NDArray, NDArray]: - sorted_idx = np.argsort(fitnesses)[::-1] - return members[sorted_idx], fitnesses[sorted_idx] diff --git a/configs/default_config.ini b/configs/default_config.ini index 29810c1..25c3ac3 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -1,7 +1,7 @@ [basic] num_inputs = 2 num_outputs = 1 -init_maximum_nodes = 200 +init_maximum_nodes = 50 init_maximum_connections = 200 init_maximum_species = 10 expand_coe = 1.5 @@ -11,9 +11,9 @@ batch_size = 4 [population] fitness_threshold = 100000 -generation_limit = 100 +generation_limit = 1000 fitness_criterion = "max" -pop_size = 150 +pop_size = 1500 [genome] compatibility_disjoint = 1.0 diff --git a/examples/evox_test.py b/examples/evox_test.py deleted file mode 100644 index 54150c1..0000000 --- a/examples/evox_test.py +++ /dev/null @@ -1,26 +0,0 @@ -import jax -from jax import numpy as jnp - -from evox import algorithms, problems, pipelines -from evox.monitors import StdSOMonitor - -monitor = StdSOMonitor() - -pso = algorithms.PSO( - lb=jnp.full(shape=(2,), fill_value=-32), - ub=jnp.full(shape=(2,), fill_value=32), - pop_size=100, -) - -ackley = problems.classic.Ackley() - -pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit) - -key = jax.random.PRNGKey(42) -state = pipeline.init(key) - -# run the pipeline for 100 steps -for i in range(100): - state = pipeline.step(state) - -print(monitor.get_min_fitness()) \ No newline at end of file diff --git a/examples/jit_xor.py b/examples/jit_xor.py deleted file mode 100644 index bd92570..0000000 --- a/examples/jit_xor.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np - -from configs import Configer -from jit_pipeline import Pipeline - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return np.array(fitnesses) # returns a list - - -def main(): - config = Configer.load_config("xor.ini") - pipeline = Pipeline(config, seed=6) - nodes, cons = pipeline.auto_run(evaluate) - print(nodes, cons) - - -if __name__ == '__main__': - main() diff --git a/examples/xor.py b/examples/xor.py index f91583a..0bded6d 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,7 +1,8 @@ import numpy as np from configs import Configer -from pipeline import Pipeline +from algorithms.neat.pipeline import Pipeline + xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -21,7 +22,8 @@ def main(): config = Configer.load_config("xor.ini") pipeline = Pipeline(config, seed=6) nodes, cons = pipeline.auto_run(evaluate) - print(nodes, cons) + print(nodes) + print(cons) if __name__ == '__main__': diff --git a/function_factory.py b/function_factory.py deleted file mode 100644 index c05cd2d..0000000 --- a/function_factory.py +++ /dev/null @@ -1,132 +0,0 @@ -import numpy as np -from jax import jit, vmap - -from algorithms.neat import create_forward, topological_sort, \ - unflatten_connections, create_next_generation_then_speciate - - -def hash_symbols(symbols): - return symbols['P'], symbols['N'], symbols['C'], symbols['S'] - - -class FunctionFactory: - """ - Creates and compiles functions used in the NEAT pipeline. - """ - - def __init__(self, config, jit_config): - self.config = config - self.jit_config = jit_config - - self.func_dict = {} - self.function_info = {} - - # (inputs_nums, ) -> (outputs_nums, ) - forward = create_forward(config) # input size (inputs_nums, ) - - # (batch_size, inputs_nums) -> (batch_size, outputs_nums) - batch_forward = vmap(forward, in_axes=(0, None, None, None)) - - # (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) - pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0)) - - # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) - common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0)) - - self.function_info = { - "pop_unflatten_connections": { - 'func': vmap(unflatten_connections), - 'lowers': [ - {'shape': ('P', 'N', 5), 'type': np.float32}, - {'shape': ('P', 'C', 4), 'type': np.float32} - ] - }, - - "pop_topological_sort": { - 'func': vmap(topological_sort), - 'lowers': [ - {'shape': ('P', 'N', 5), 'type': np.float32}, - {'shape': ('P', 2, 'N', 'N'), 'type': np.float32}, - ] - }, - - "batch_forward": { - 'func': batch_forward, - 'lowers': [ - {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32}, - {'shape': ('N',), 'type': np.int32}, - {'shape': ('N', 5), 'type': np.float32}, - {'shape': (2, 'N', 'N'), 'type': np.float32} - ] - }, - - "pop_batch_forward": { - 'func': pop_batch_forward, - 'lowers': [ - {'shape': ('P', config['batch_size'], config['num_inputs']), 'type': np.float32}, - {'shape': ('P', 'N'), 'type': np.int32}, - {'shape': ('P', 'N', 5), 'type': np.float32}, - {'shape': ('P', 2, 'N', 'N'), 'type': np.float32} - ] - }, - - 'common_forward': { - 'func': common_forward, - 'lowers': [ - {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32}, - {'shape': ('P', 'N'), 'type': np.int32}, - {'shape': ('P', 'N', 5), 'type': np.float32}, - {'shape': ('P', 2, 'N', 'N'), 'type': np.float32} - ] - }, - - 'create_next_generation_then_speciate': { - 'func': create_next_generation_then_speciate, - 'lowers': [ - {'shape': (2,), 'type': np.uint32}, # rand_key - {'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes - {'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons - {'shape': ('P',), 'type': np.int32}, # winner - {'shape': ('P',), 'type': np.int32}, # loser - {'shape': ('P',), 'type': bool}, # elite_mask - {'shape': ('P',), 'type': np.int32}, # new_node_keys - {'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes - {'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons - {'shape': ('S',), 'type': np.int32}, # species_keys - {'shape': (), 'type': np.int32}, # new_species_key_start - "jit_config" - ] - } - } - - def get(self, name, symbols): - if (name, hash_symbols(symbols)) not in self.func_dict: - self.compile(name, symbols) - return self.func_dict[name, hash_symbols(symbols)] - - def compile(self, name, symbols): - # prepare function prototype - func = self.function_info[name]['func'] - - # prepare lower operands - lowers_operands = [] - for lower in self.function_info[name]['lowers']: - if isinstance(lower, dict): - shape = list(lower['shape']) - for i, s in enumerate(shape): - if s in symbols: - shape[i] = symbols[s] - assert isinstance(shape[i], int) - lowers_operands.append(np.zeros(shape, dtype=lower['type'])) - - elif lower == "jit_config": - lowers_operands.append(self.jit_config) - - else: - raise ValueError("Invalid lower operand") - - # compile - compiled_func = jit(func).lower(*lowers_operands).compile() - - # save for reuse - self.func_dict[name, hash_symbols(symbols)] = compiled_func diff --git a/pipeline.py b/pipeline.py deleted file mode 100644 index 2b6bf37..0000000 --- a/pipeline.py +++ /dev/null @@ -1,189 +0,0 @@ -import time -from typing import Union, Callable - -import numpy as np -import jax - -from configs import Configer -from function_factory import FunctionFactory -from algorithms.neat import initialize_genomes, expand, expand_single, SpeciesController - - -class Pipeline: - """ - Neat algorithm pipeline. - """ - - def __init__(self, config, function_factory=None, seed=42): - self.randkey = jax.random.PRNGKey(seed) - np.random.seed(seed) - - self.config = config # global config - self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions - self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config) - - self.symbols = { - 'P': self.config['pop_size'], - 'N': self.config['init_maximum_nodes'], - 'C': self.config['init_maximum_connections'], - 'S': self.config['init_maximum_species'], - } - - self.generation = 0 - self.best_genome = None - - self.species_controller = SpeciesController(self.config) - self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config) - self.species_controller.init_speciate(self.pop_nodes, self.pop_cons) - - self.best_fitness = float('-inf') - self.best_genome = None - self.generation_timestamp = time.time() - - self.evaluate_time = 0 - print(self.config) - - def ask(self): - """ - Creates a function that receives a genome and returns a forward function. - There are 3 types of config['forward_way']: {'single', 'pop', 'common'} - - single: - Create pop_size number of forward functions. - Each function receive (batch_size, input_size) and returns (batch_size, output_size) - e.g. RL task - - pop: - Create a single forward function, which use only once calculation for the population. - The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size) - - common: - Special case of pop. The population has the same inputs. - The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size) - e.g. numerical regression; Hyper-NEAT - - """ - u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons) - pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons) - - if self.config['forward_way'] == 'single': - forward_funcs = [] - for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons): - func = lambda x: self.get_func('forward')(x, seq, nodes, cons) - forward_funcs.append(func) - return forward_funcs - - elif self.config['forward_way'] == 'pop': - func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) - return func - - elif self.config['forward_way'] == 'common': - func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons) - return func - - else: - raise NotImplementedError - - def tell(self, fitnesses): - self.generation += 1 - - winner, loser, elite_mask, center_nodes, center_cons, species_keys, species_key_start = \ - self.species_controller.ask(fitnesses, self.generation, self.symbols) - - # node keys to be used in the mutation process - new_node_keys = np.arange(self.generation * self.config['pop_size'], - self.generation * self.config['pop_size'] + self.config['pop_size']) - - # create the next generation and then speciate the population - self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ - self.get_func('create_next_generation_then_speciate') \ - (self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes, - center_cons, species_keys, species_key_start, self.jit_config) - - # carry data to cpu - self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \ - jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys]) - - self.species_controller.tell(idx2specie, center_nodes, center_cons, species_keys, self.generation) - - # expand the population if needed - self.expand() - - # update randkey - self.randkey = jax.random.split(self.randkey)[0] - - def expand(self): - """ - Expand the population if needed. - when the maximum node number >= N or the maximum connection number of >= C - the population will expand - """ - - # analysis nodes - pop_node_keys = self.pop_nodes[:, :, 0] - pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1) - max_node_size = np.max(pop_node_sizes) - - # analysis connections - pop_con_keys = self.pop_cons[:, :, 0] - pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) - max_con_size = np.max(pop_node_sizes) - - # expand if needed - if max_node_size >= self.symbols['N'] or max_con_size >= self.symbols['C']: - if max_node_size > self.symbols['N'] * self.config['pre_expand_threshold']: - self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe']) - print(f"pre node expand to {self.symbols['N']}!") - - if max_con_size > self.symbols['C'] * self.config['pre_expand_threshold']: - self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe']) - print(f"pre connection expand to {self.symbols['C']}!") - - self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C']) - # don't forget to expand representation genome in species - for s in self.species_controller.species.values(): - s.representative = expand_single(*s.representative, self.symbols['N'], self.symbols['C']) - - def get_func(self, name): - return self.function_factory.get(name, self.symbols) - - def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): - for _ in range(self.config['generation_limit']): - forward_func = self.ask() - - tic = time.time() - fitnesses = fitness_func(forward_func) - self.evaluate_time += time.time() - tic - - assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" - - if analysis is not None: - if analysis == "default": - self.default_analysis(fitnesses) - else: - assert callable(analysis), f"Callable is needed hereπŸ˜…πŸ˜…πŸ˜… A {analysis}?" - analysis(fitnesses) - - if max(fitnesses) >= self.config['fitness_threshold']: - print("Fitness limit reached!") - return self.best_genome - - self.tell(fitnesses) - print("Generation limit reached!") - return self.best_genome - - def default_analysis(self, fitnesses): - max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) - species_sizes = [len(s.members) for s in self.species_controller.species.values()] - - new_timestamp = time.time() - cost_time = new_timestamp - self.generation_timestamp - self.generation_timestamp = new_timestamp - - max_idx = np.argmax(fitnesses) - if fitnesses[max_idx] > self.best_fitness: - self.best_fitness = fitnesses[max_idx] - self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) - - print(f"Generation: {self.generation}", - f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")