finish jit-able speciate function
next time i'll create a new branch
This commit is contained in:
@@ -23,8 +23,6 @@ if __name__ == '__main__':
|
||||
new_node_idx += len(pop_nodes)
|
||||
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
||||
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
||||
# for i in range(len(pop_nodes)):
|
||||
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||
idx1 = np.random.permutation(len(pop_nodes))
|
||||
idx2 = np.random.permutation(len(pop_nodes))
|
||||
|
||||
@@ -32,13 +30,6 @@ if __name__ == '__main__':
|
||||
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
||||
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
||||
|
||||
# for idx, (zn1, zc1, zn2, zc2) in enumerate(zip(n1, c1, n2, c2)):
|
||||
# n, c = crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
||||
# try:
|
||||
# check_array_valid(n, c, input_idx, output_idx)
|
||||
# except AssertionError as e:
|
||||
# crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
||||
|
||||
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
|
||||
|
||||
for i in range(len(pop_nodes)):
|
||||
|
||||
@@ -3,56 +3,45 @@ import jax.numpy as jnp
|
||||
from jax import jit, vmap
|
||||
from time_utils import using_cprofile
|
||||
from time import time
|
||||
|
||||
#
|
||||
@jit
|
||||
def fx(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
# # @jit
|
||||
# def fy(z):
|
||||
# z1, z2 = z, z + 1
|
||||
# vmap_fx = vmap(fx)
|
||||
# return vmap_fx(z1, z2)
|
||||
#
|
||||
# @jit
|
||||
def fy(z):
|
||||
z1, z2 = z, z + 1
|
||||
vmap_fx = vmap(fx)
|
||||
return vmap_fx(z1, z2)
|
||||
|
||||
@jit
|
||||
def test_while(num, init_val):
|
||||
def cond_fun(carry):
|
||||
i, cumsum = carry
|
||||
return i < num
|
||||
|
||||
def body_fun(carry):
|
||||
i, cumsum = carry
|
||||
cumsum += i
|
||||
return i + 1, cumsum
|
||||
|
||||
return jax.lax.while_loop(cond_fun, body_fun, (0, init_val))
|
||||
# def test_while(num, init_val):
|
||||
# def cond_fun(carry):
|
||||
# i, cumsum = carry
|
||||
# return i < num
|
||||
#
|
||||
# def body_fun(carry):
|
||||
# i, cumsum = carry
|
||||
# cumsum += i
|
||||
# return i + 1, cumsum
|
||||
#
|
||||
# return jax.lax.while_loop(cond_fun, body_fun, (0, init_val))
|
||||
|
||||
|
||||
|
||||
@using_cprofile
|
||||
|
||||
# @using_cprofile
|
||||
def main():
|
||||
z = jnp.zeros((100000, ))
|
||||
vmap_f = vmap(fx, in_axes=(None, 0))
|
||||
vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
|
||||
a = jnp.array([20,10,30])
|
||||
b = jnp.array([6, 5, 4])
|
||||
res = vmap_vmap_f(a, b)
|
||||
print(res)
|
||||
print(jnp.argmin(res, axis=1))
|
||||
|
||||
num = 100
|
||||
|
||||
nums = jnp.arange(num) * 10
|
||||
|
||||
f = jit(vmap(test_while, in_axes=(0, None))).lower(nums, z).compile()
|
||||
def test_time(*args):
|
||||
return f(*args)
|
||||
|
||||
print(test_time(nums, z))
|
||||
|
||||
#
|
||||
#
|
||||
# for i in range(10):
|
||||
# num = 10 ** i
|
||||
# st = time()
|
||||
# res = test_time(num, z)
|
||||
# print(res)
|
||||
# t = time() - st
|
||||
# print(f"num: {num}, time: {t}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
67
examples/jitable_speciate_t.py
Normal file
67
examples/jitable_speciate_t.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from algorithms.neat.function_factory import FunctionFactory
|
||||
from algorithms.neat.genome.debug.tools import check_array_valid
|
||||
from utils import Configer
|
||||
from algorithms.neat.jitable_speciate import jitable_speciate
|
||||
from algorithms.neat.genome.crossover import crossover
|
||||
from algorithms.neat.genome.utils import I_INT
|
||||
from time import time
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Configer.load_config()
|
||||
function_factory = FunctionFactory(config, debug=True)
|
||||
initialize_func = function_factory.create_initialize()
|
||||
|
||||
pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
|
||||
mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
|
||||
crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
|
||||
|
||||
N, C, species_size = function_factory.init_N, function_factory.init_C, 20
|
||||
spe_center_nodes = np.full((species_size, N, 5), np.nan)
|
||||
spe_center_connections = np.full((species_size, C, 4), np.nan)
|
||||
spe_center_nodes[0] = pop_nodes[0]
|
||||
spe_center_connections[0] = pop_connections[0]
|
||||
|
||||
key = jax.random.PRNGKey(0)
|
||||
new_node_idx = 100
|
||||
|
||||
while True:
|
||||
start_time = time()
|
||||
key, subkey = jax.random.split(key)
|
||||
mutate_keys = jax.random.split(subkey, len(pop_nodes))
|
||||
new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
|
||||
new_node_idx += len(pop_nodes)
|
||||
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
||||
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
||||
# for i in range(len(pop_nodes)):
|
||||
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||
idx1 = np.random.permutation(len(pop_nodes))
|
||||
idx2 = np.random.permutation(len(pop_nodes))
|
||||
|
||||
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
|
||||
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
||||
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
||||
|
||||
# for i in range(len(pop_nodes)):
|
||||
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||
|
||||
#speciate next generation
|
||||
|
||||
idx2specie, spe_center_nodes, spe_center_cons = jitable_speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections,
|
||||
compatibility_threshold=2.5)
|
||||
|
||||
idx2specie = np.array(idx2specie)
|
||||
spe_dict = {}
|
||||
for i in range(len(idx2specie)):
|
||||
spe_idx = idx2specie[i]
|
||||
if spe_idx not in spe_dict:
|
||||
spe_dict[spe_idx] = 1
|
||||
else:
|
||||
spe_dict[spe_idx] += 1
|
||||
|
||||
print(spe_dict)
|
||||
assert np.all(idx2specie != I_INT)
|
||||
print(time() - start_time)
|
||||
# print(idx2specie)
|
||||
Reference in New Issue
Block a user