FAST!
This commit is contained in:
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from algorithms.neat.function_factory import FunctionFactory
|
||||
from algorithms.neat.genome.debug.tools import check_array_valid
|
||||
from utils import Configer
|
||||
from algorithms.neat.jitable_speciate import jitable_speciate
|
||||
from algorithms.neat.population import speciate
|
||||
from algorithms.neat.genome.crossover import crossover
|
||||
from algorithms.neat.genome.utils import I_INT
|
||||
from time import time
|
||||
@@ -23,7 +23,9 @@ if __name__ == '__main__':
|
||||
spe_center_connections = np.full((species_size, C, 4), np.nan)
|
||||
spe_center_nodes[0] = pop_nodes[0]
|
||||
spe_center_connections[0] = pop_connections[0]
|
||||
|
||||
spe_keys = np.full((species_size,), I_INT)
|
||||
spe_keys[0] = 0
|
||||
new_spe_key = 1
|
||||
key = jax.random.PRNGKey(0)
|
||||
new_node_idx = 100
|
||||
|
||||
@@ -43,25 +45,31 @@ if __name__ == '__main__':
|
||||
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
|
||||
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
||||
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
||||
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
|
||||
|
||||
|
||||
# for i in range(len(pop_nodes)):
|
||||
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||
|
||||
#speciate next generation
|
||||
|
||||
idx2specie, spe_center_nodes, spe_center_cons = jitable_speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections,
|
||||
compatibility_threshold=2.5)
|
||||
idx2specie, spe_center_nodes, spe_center_cons, spe_keys, new_spe_key = speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections,
|
||||
spe_keys, new_spe_key,
|
||||
compatibility_threshold=3)
|
||||
|
||||
idx2specie = np.array(idx2specie)
|
||||
spe_dict = {}
|
||||
for i in range(len(idx2specie)):
|
||||
spe_idx = idx2specie[i]
|
||||
if spe_idx not in spe_dict:
|
||||
spe_dict[spe_idx] = 1
|
||||
else:
|
||||
spe_dict[spe_idx] += 1
|
||||
print(spe_keys, new_spe_key)
|
||||
|
||||
print(spe_dict)
|
||||
assert np.all(idx2specie != I_INT)
|
||||
#
|
||||
# idx2specie = np.array(idx2specie)
|
||||
# spe_dict = {}
|
||||
# for i in range(len(idx2specie)):
|
||||
# spe_idx = idx2specie[i]
|
||||
# if spe_idx not in spe_dict:
|
||||
# spe_dict[spe_idx] = 1
|
||||
# else:
|
||||
# spe_dict[spe_idx] += 1
|
||||
#
|
||||
# print(spe_dict)
|
||||
# assert np.all(idx2specie != I_INT)
|
||||
print(time() - start_time)
|
||||
# print(idx2specie)
|
||||
|
||||
Reference in New Issue
Block a user