From 6006f92f3f07adbcc7a4ad93284532fff98b0ea0 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 12 May 2023 19:26:02 +0800 Subject: [PATCH] finish jit-able speciate function next time i'll create a new branch --- algorithms/neat/genome/utils.py | 6 ++ algorithms/neat/jitable_speciate.py | 111 +++++++++++++++++++++++++++- examples/function_tests.py | 9 --- examples/jax_playground.py | 69 ++++++++--------- examples/jitable_speciate_t.py | 67 +++++++++++++++++ utils/default_config.json | 4 +- 6 files changed, 212 insertions(+), 54 deletions(-) create mode 100644 examples/jitable_speciate_t.py diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py index 19703e3..8e662e7 100644 --- a/algorithms/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -76,6 +76,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array: return fetch_first(mask, default) +@jit +def argmin_with_mask(arr: Array, mask: Array) -> Array: + masked_arr = jnp.where(mask, arr, jnp.inf) + min_idx = jnp.argmin(masked_arr) + return min_idx + if __name__ == '__main__': a = jnp.array([1, 2, 3, 4, 5]) diff --git a/algorithms/neat/jitable_speciate.py b/algorithms/neat/jitable_speciate.py index c9b16c4..8f2192e 100644 --- a/algorithms/neat/jitable_speciate.py +++ b/algorithms/neat/jitable_speciate.py @@ -1,4 +1,109 @@ -from jax import jit +from functools import partial + +import jax +import jax.numpy as jnp +from jax import jit, vmap + +from jax import Array + +from .genome import distance +from .genome.utils import I_INT, fetch_first, argmin_with_mask + + @jit -def jitable_speciate(): - pass \ No newline at end of file +def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array, + disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0 + ): + """ + args: + pop_nodes: (pop_size, N, 5) + pop_cons: (pop_size, C, 4) + spe_center_nodes: (species_size, N, 5) + spe_center_cons: (species_size, C, 4) + """ + pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0] + + # prepare distance functions + distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe) + o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) + s2p_distance_func = vmap( + o2p_distance_func, in_axes=(0, 0, None, None) + ) + + idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species + + # the distance between each species' center and each genome in population + s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons) + + def continue_execute_while(carry): + i, i2s, scn, scc = carry + not_all_assigned = ~jnp.all(i2s != I_INT) + not_reach_species_upper_bounds = i < species_size + return not_all_assigned & not_reach_species_upper_bounds + + def deal_with_each_center_genome(carry): + i, i2s, scn, scc = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons + center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i] + + i2s, scn, scc = jax.lax.cond( + jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid + create_new_specie, # if not valid, create a new specie + update_exist_specie, # if valid, update the specie + (i, i2s, scn, scc) + ) + + return i + 1, i2s, scn, scc + + def create_new_specie(carry): + i, i2s, scn, scc = carry + # pick the first one who has not been assigned to any species + idx = fetch_first(i2s == I_INT) + + # assign it to new specie + i2s = i2s.at[idx].set(i) + + # update center genomes + scn = scn.at[i].set(pop_nodes[idx]) + scc = scc.at[i].set(pop_cons[idx]) + + i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc)) + return i2s, scn, scc + + def update_exist_specie(carry): + i, i2s, scn, scc = carry + + # find new center + idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT) + + # update new center + i2s = i2s.at[idx].set(i) + + # update center genomes + scn = scn.at[i].set(pop_nodes[idx]) + scc = scc.at[i].set(pop_cons[idx]) + + i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc)) + return i2s, scn, scc + + def speciate_by_threshold(carry): + i, i2s, scn, scc = carry + # distance between such center genome and ppo genomes + o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons) + close_enough_mask = o2p_distance < compatibility_threshold + + # when it is close enough, assign it to the species, remember not to update genome has already been assigned + i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s) + return i2s, scn, scc + + # update idx2specie + _, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop( + continue_execute_while, + deal_with_each_center_genome, + (0, idx2specie, spe_center_nodes, spe_center_cons) + ) + + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition seems to be only happened when the number of species is reached species upper bounds + idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie) + + return idx2specie, spe_center_nodes, spe_center_cons diff --git a/examples/function_tests.py b/examples/function_tests.py index 8de0e3b..5426837 100644 --- a/examples/function_tests.py +++ b/examples/function_tests.py @@ -23,8 +23,6 @@ if __name__ == '__main__': new_node_idx += len(pop_nodes) pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes) pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections]) - # for i in range(len(pop_nodes)): - # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) idx1 = np.random.permutation(len(pop_nodes)) idx2 = np.random.permutation(len(pop_nodes)) @@ -32,13 +30,6 @@ if __name__ == '__main__': n2, c2 = pop_nodes[idx2], pop_connections[idx2] crossover_keys = jax.random.split(subkey, len(pop_nodes)) - # for idx, (zn1, zc1, zn2, zc2) in enumerate(zip(n1, c1, n2, c2)): - # n, c = crossover(crossover_keys[idx], zn1, zc1, zn2, zc2) - # try: - # check_array_valid(n, c, input_idx, output_idx) - # except AssertionError as e: - # crossover(crossover_keys[idx], zn1, zc1, zn2, zc2) - pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2) for i in range(len(pop_nodes)): diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 226532e..63df18f 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -3,56 +3,45 @@ import jax.numpy as jnp from jax import jit, vmap from time_utils import using_cprofile from time import time - +# @jit def fx(x, y): return x + y - - +# +# +# # @jit +# def fy(z): +# z1, z2 = z, z + 1 +# vmap_fx = vmap(fx) +# return vmap_fx(z1, z2) +# # @jit -def fy(z): - z1, z2 = z, z + 1 - vmap_fx = vmap(fx) - return vmap_fx(z1, z2) - -@jit -def test_while(num, init_val): - def cond_fun(carry): - i, cumsum = carry - return i < num - - def body_fun(carry): - i, cumsum = carry - cumsum += i - return i + 1, cumsum - - return jax.lax.while_loop(cond_fun, body_fun, (0, init_val)) +# def test_while(num, init_val): +# def cond_fun(carry): +# i, cumsum = carry +# return i < num +# +# def body_fun(carry): +# i, cumsum = carry +# cumsum += i +# return i + 1, cumsum +# +# return jax.lax.while_loop(cond_fun, body_fun, (0, init_val)) -@using_cprofile + +# @using_cprofile def main(): - z = jnp.zeros((100000, )) + vmap_f = vmap(fx, in_axes=(None, 0)) + vmap_vmap_f = vmap(vmap_f, in_axes=(0, None)) + a = jnp.array([20,10,30]) + b = jnp.array([6, 5, 4]) + res = vmap_vmap_f(a, b) + print(res) + print(jnp.argmin(res, axis=1)) - num = 100 - nums = jnp.arange(num) * 10 - - f = jit(vmap(test_while, in_axes=(0, None))).lower(nums, z).compile() - def test_time(*args): - return f(*args) - - print(test_time(nums, z)) - - # - # - # for i in range(10): - # num = 10 ** i - # st = time() - # res = test_time(num, z) - # print(res) - # t = time() - st - # print(f"num: {num}, time: {t}") if __name__ == '__main__': main() diff --git a/examples/jitable_speciate_t.py b/examples/jitable_speciate_t.py new file mode 100644 index 0000000..8629665 --- /dev/null +++ b/examples/jitable_speciate_t.py @@ -0,0 +1,67 @@ +import jax +import jax.numpy as jnp +import numpy as np +from algorithms.neat.function_factory import FunctionFactory +from algorithms.neat.genome.debug.tools import check_array_valid +from utils import Configer +from algorithms.neat.jitable_speciate import jitable_speciate +from algorithms.neat.genome.crossover import crossover +from algorithms.neat.genome.utils import I_INT +from time import time + +if __name__ == '__main__': + config = Configer.load_config() + function_factory = FunctionFactory(config, debug=True) + initialize_func = function_factory.create_initialize() + + pop_nodes, pop_connections, input_idx, output_idx = initialize_func() + mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1]) + crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1]) + + N, C, species_size = function_factory.init_N, function_factory.init_C, 20 + spe_center_nodes = np.full((species_size, N, 5), np.nan) + spe_center_connections = np.full((species_size, C, 4), np.nan) + spe_center_nodes[0] = pop_nodes[0] + spe_center_connections[0] = pop_connections[0] + + key = jax.random.PRNGKey(0) + new_node_idx = 100 + + while True: + start_time = time() + key, subkey = jax.random.split(key) + mutate_keys = jax.random.split(subkey, len(pop_nodes)) + new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes)) + new_node_idx += len(pop_nodes) + pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes) + pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections]) + # for i in range(len(pop_nodes)): + # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) + idx1 = np.random.permutation(len(pop_nodes)) + idx2 = np.random.permutation(len(pop_nodes)) + + n1, c1 = pop_nodes[idx1], pop_connections[idx1] + n2, c2 = pop_nodes[idx2], pop_connections[idx2] + crossover_keys = jax.random.split(subkey, len(pop_nodes)) + + # for i in range(len(pop_nodes)): + # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) + + #speciate next generation + + idx2specie, spe_center_nodes, spe_center_cons = jitable_speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections, + compatibility_threshold=2.5) + + idx2specie = np.array(idx2specie) + spe_dict = {} + for i in range(len(idx2specie)): + spe_idx = idx2specie[i] + if spe_idx not in spe_dict: + spe_dict[spe_idx] = 1 + else: + spe_dict[spe_idx] += 1 + + print(spe_dict) + assert np.all(idx2specie != I_INT) + print(time() - start_time) + # print(idx2specie) diff --git a/utils/default_config.json b/utils/default_config.json index 50fce9d..4ba0142 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -3,8 +3,8 @@ "num_inputs": 2, "num_outputs": 1, "problem_batch": 4, - "init_maximum_nodes": 10, - "init_maximum_connections": 10, + "init_maximum_nodes": 50, + "init_maximum_connections": 50, "expands_coe": 2, "pre_compile_times": 3, "forward_way": "pop_batch"