finish jit-able speciate function
next time i'll create a new branch
This commit is contained in:
@@ -76,6 +76,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
|||||||
return fetch_first(mask, default)
|
return fetch_first(mask, default)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||||
|
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||||
|
min_idx = jnp.argmin(masked_arr)
|
||||||
|
return min_idx
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
a = jnp.array([1, 2, 3, 4, 5])
|
a = jnp.array([1, 2, 3, 4, 5])
|
||||||
|
|||||||
@@ -1,4 +1,109 @@
|
|||||||
from jax import jit
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import jit, vmap
|
||||||
|
|
||||||
|
from jax import Array
|
||||||
|
|
||||||
|
from .genome import distance
|
||||||
|
from .genome.utils import I_INT, fetch_first, argmin_with_mask
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def jitable_speciate():
|
def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
|
||||||
pass
|
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
args:
|
||||||
|
pop_nodes: (pop_size, N, 5)
|
||||||
|
pop_cons: (pop_size, C, 4)
|
||||||
|
spe_center_nodes: (species_size, N, 5)
|
||||||
|
spe_center_cons: (species_size, C, 4)
|
||||||
|
"""
|
||||||
|
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
|
||||||
|
|
||||||
|
# prepare distance functions
|
||||||
|
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
|
||||||
|
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||||
|
s2p_distance_func = vmap(
|
||||||
|
o2p_distance_func, in_axes=(0, 0, None, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||||
|
|
||||||
|
# the distance between each species' center and each genome in population
|
||||||
|
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
|
||||||
|
|
||||||
|
def continue_execute_while(carry):
|
||||||
|
i, i2s, scn, scc = carry
|
||||||
|
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||||
|
not_reach_species_upper_bounds = i < species_size
|
||||||
|
return not_all_assigned & not_reach_species_upper_bounds
|
||||||
|
|
||||||
|
def deal_with_each_center_genome(carry):
|
||||||
|
i, i2s, scn, scc = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||||
|
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
|
||||||
|
|
||||||
|
i2s, scn, scc = jax.lax.cond(
|
||||||
|
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
|
||||||
|
create_new_specie, # if not valid, create a new specie
|
||||||
|
update_exist_specie, # if valid, update the specie
|
||||||
|
(i, i2s, scn, scc)
|
||||||
|
)
|
||||||
|
|
||||||
|
return i + 1, i2s, scn, scc
|
||||||
|
|
||||||
|
def create_new_specie(carry):
|
||||||
|
i, i2s, scn, scc = carry
|
||||||
|
# pick the first one who has not been assigned to any species
|
||||||
|
idx = fetch_first(i2s == I_INT)
|
||||||
|
|
||||||
|
# assign it to new specie
|
||||||
|
i2s = i2s.at[idx].set(i)
|
||||||
|
|
||||||
|
# update center genomes
|
||||||
|
scn = scn.at[i].set(pop_nodes[idx])
|
||||||
|
scc = scc.at[i].set(pop_cons[idx])
|
||||||
|
|
||||||
|
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
|
||||||
|
return i2s, scn, scc
|
||||||
|
|
||||||
|
def update_exist_specie(carry):
|
||||||
|
i, i2s, scn, scc = carry
|
||||||
|
|
||||||
|
# find new center
|
||||||
|
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||||
|
|
||||||
|
# update new center
|
||||||
|
i2s = i2s.at[idx].set(i)
|
||||||
|
|
||||||
|
# update center genomes
|
||||||
|
scn = scn.at[i].set(pop_nodes[idx])
|
||||||
|
scc = scc.at[i].set(pop_cons[idx])
|
||||||
|
|
||||||
|
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
|
||||||
|
return i2s, scn, scc
|
||||||
|
|
||||||
|
def speciate_by_threshold(carry):
|
||||||
|
i, i2s, scn, scc = carry
|
||||||
|
# distance between such center genome and ppo genomes
|
||||||
|
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
|
||||||
|
close_enough_mask = o2p_distance < compatibility_threshold
|
||||||
|
|
||||||
|
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||||
|
i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s)
|
||||||
|
return i2s, scn, scc
|
||||||
|
|
||||||
|
# update idx2specie
|
||||||
|
_, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop(
|
||||||
|
continue_execute_while,
|
||||||
|
deal_with_each_center_genome,
|
||||||
|
(0, idx2specie, spe_center_nodes, spe_center_cons)
|
||||||
|
)
|
||||||
|
|
||||||
|
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||||
|
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||||
|
idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie)
|
||||||
|
|
||||||
|
return idx2specie, spe_center_nodes, spe_center_cons
|
||||||
|
|||||||
@@ -23,8 +23,6 @@ if __name__ == '__main__':
|
|||||||
new_node_idx += len(pop_nodes)
|
new_node_idx += len(pop_nodes)
|
||||||
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
||||||
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
||||||
# for i in range(len(pop_nodes)):
|
|
||||||
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
|
||||||
idx1 = np.random.permutation(len(pop_nodes))
|
idx1 = np.random.permutation(len(pop_nodes))
|
||||||
idx2 = np.random.permutation(len(pop_nodes))
|
idx2 = np.random.permutation(len(pop_nodes))
|
||||||
|
|
||||||
@@ -32,13 +30,6 @@ if __name__ == '__main__':
|
|||||||
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
||||||
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
||||||
|
|
||||||
# for idx, (zn1, zc1, zn2, zc2) in enumerate(zip(n1, c1, n2, c2)):
|
|
||||||
# n, c = crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
|
||||||
# try:
|
|
||||||
# check_array_valid(n, c, input_idx, output_idx)
|
|
||||||
# except AssertionError as e:
|
|
||||||
# crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
|
||||||
|
|
||||||
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
|
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
|
||||||
|
|
||||||
for i in range(len(pop_nodes)):
|
for i in range(len(pop_nodes)):
|
||||||
|
|||||||
@@ -3,56 +3,45 @@ import jax.numpy as jnp
|
|||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
from time_utils import using_cprofile
|
from time_utils import using_cprofile
|
||||||
from time import time
|
from time import time
|
||||||
|
#
|
||||||
@jit
|
@jit
|
||||||
def fx(x, y):
|
def fx(x, y):
|
||||||
return x + y
|
return x + y
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# # @jit
|
||||||
|
# def fy(z):
|
||||||
|
# z1, z2 = z, z + 1
|
||||||
|
# vmap_fx = vmap(fx)
|
||||||
|
# return vmap_fx(z1, z2)
|
||||||
|
#
|
||||||
# @jit
|
# @jit
|
||||||
def fy(z):
|
# def test_while(num, init_val):
|
||||||
z1, z2 = z, z + 1
|
# def cond_fun(carry):
|
||||||
vmap_fx = vmap(fx)
|
# i, cumsum = carry
|
||||||
return vmap_fx(z1, z2)
|
# return i < num
|
||||||
|
#
|
||||||
@jit
|
# def body_fun(carry):
|
||||||
def test_while(num, init_val):
|
# i, cumsum = carry
|
||||||
def cond_fun(carry):
|
# cumsum += i
|
||||||
i, cumsum = carry
|
# return i + 1, cumsum
|
||||||
return i < num
|
#
|
||||||
|
# return jax.lax.while_loop(cond_fun, body_fun, (0, init_val))
|
||||||
def body_fun(carry):
|
|
||||||
i, cumsum = carry
|
|
||||||
cumsum += i
|
|
||||||
return i + 1, cumsum
|
|
||||||
|
|
||||||
return jax.lax.while_loop(cond_fun, body_fun, (0, init_val))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@using_cprofile
|
|
||||||
|
# @using_cprofile
|
||||||
def main():
|
def main():
|
||||||
z = jnp.zeros((100000, ))
|
vmap_f = vmap(fx, in_axes=(None, 0))
|
||||||
|
vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
|
||||||
|
a = jnp.array([20,10,30])
|
||||||
|
b = jnp.array([6, 5, 4])
|
||||||
|
res = vmap_vmap_f(a, b)
|
||||||
|
print(res)
|
||||||
|
print(jnp.argmin(res, axis=1))
|
||||||
|
|
||||||
num = 100
|
|
||||||
|
|
||||||
nums = jnp.arange(num) * 10
|
|
||||||
|
|
||||||
f = jit(vmap(test_while, in_axes=(0, None))).lower(nums, z).compile()
|
|
||||||
def test_time(*args):
|
|
||||||
return f(*args)
|
|
||||||
|
|
||||||
print(test_time(nums, z))
|
|
||||||
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# for i in range(10):
|
|
||||||
# num = 10 ** i
|
|
||||||
# st = time()
|
|
||||||
# res = test_time(num, z)
|
|
||||||
# print(res)
|
|
||||||
# t = time() - st
|
|
||||||
# print(f"num: {num}, time: {t}")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
67
examples/jitable_speciate_t.py
Normal file
67
examples/jitable_speciate_t.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from algorithms.neat.function_factory import FunctionFactory
|
||||||
|
from algorithms.neat.genome.debug.tools import check_array_valid
|
||||||
|
from utils import Configer
|
||||||
|
from algorithms.neat.jitable_speciate import jitable_speciate
|
||||||
|
from algorithms.neat.genome.crossover import crossover
|
||||||
|
from algorithms.neat.genome.utils import I_INT
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Configer.load_config()
|
||||||
|
function_factory = FunctionFactory(config, debug=True)
|
||||||
|
initialize_func = function_factory.create_initialize()
|
||||||
|
|
||||||
|
pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
|
||||||
|
mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
|
||||||
|
crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
|
||||||
|
|
||||||
|
N, C, species_size = function_factory.init_N, function_factory.init_C, 20
|
||||||
|
spe_center_nodes = np.full((species_size, N, 5), np.nan)
|
||||||
|
spe_center_connections = np.full((species_size, C, 4), np.nan)
|
||||||
|
spe_center_nodes[0] = pop_nodes[0]
|
||||||
|
spe_center_connections[0] = pop_connections[0]
|
||||||
|
|
||||||
|
key = jax.random.PRNGKey(0)
|
||||||
|
new_node_idx = 100
|
||||||
|
|
||||||
|
while True:
|
||||||
|
start_time = time()
|
||||||
|
key, subkey = jax.random.split(key)
|
||||||
|
mutate_keys = jax.random.split(subkey, len(pop_nodes))
|
||||||
|
new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
|
||||||
|
new_node_idx += len(pop_nodes)
|
||||||
|
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
||||||
|
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
||||||
|
# for i in range(len(pop_nodes)):
|
||||||
|
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||||
|
idx1 = np.random.permutation(len(pop_nodes))
|
||||||
|
idx2 = np.random.permutation(len(pop_nodes))
|
||||||
|
|
||||||
|
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
|
||||||
|
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
||||||
|
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
||||||
|
|
||||||
|
# for i in range(len(pop_nodes)):
|
||||||
|
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||||
|
|
||||||
|
#speciate next generation
|
||||||
|
|
||||||
|
idx2specie, spe_center_nodes, spe_center_cons = jitable_speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections,
|
||||||
|
compatibility_threshold=2.5)
|
||||||
|
|
||||||
|
idx2specie = np.array(idx2specie)
|
||||||
|
spe_dict = {}
|
||||||
|
for i in range(len(idx2specie)):
|
||||||
|
spe_idx = idx2specie[i]
|
||||||
|
if spe_idx not in spe_dict:
|
||||||
|
spe_dict[spe_idx] = 1
|
||||||
|
else:
|
||||||
|
spe_dict[spe_idx] += 1
|
||||||
|
|
||||||
|
print(spe_dict)
|
||||||
|
assert np.all(idx2specie != I_INT)
|
||||||
|
print(time() - start_time)
|
||||||
|
# print(idx2specie)
|
||||||
@@ -3,8 +3,8 @@
|
|||||||
"num_inputs": 2,
|
"num_inputs": 2,
|
||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"problem_batch": 4,
|
"problem_batch": 4,
|
||||||
"init_maximum_nodes": 10,
|
"init_maximum_nodes": 50,
|
||||||
"init_maximum_connections": 10,
|
"init_maximum_connections": 50,
|
||||||
"expands_coe": 2,
|
"expands_coe": 2,
|
||||||
"pre_compile_times": 3,
|
"pre_compile_times": 3,
|
||||||
"forward_way": "pop_batch"
|
"forward_way": "pop_batch"
|
||||||
|
|||||||
Reference in New Issue
Block a user