This commit is contained in:
wls2002
2023-05-13 20:58:03 +08:00
parent 90a9cc322d
commit 72c9d4167a
10 changed files with 372 additions and 529 deletions

View File

@@ -4,9 +4,10 @@ from jax import jit, vmap
from time_utils import using_cprofile
from time import time
#
import numpy as np
@jit
def fx(x, y):
return x + y
def fx(x):
return jnp.arange(x, x + 10)
#
#
# # @jit
@@ -33,13 +34,15 @@ def fx(x, y):
# @using_cprofile
def main():
vmap_f = vmap(fx, in_axes=(None, 0))
vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
a = jnp.array([20,10,30])
b = jnp.array([6, 5, 4])
res = vmap_vmap_f(a, b)
print(res)
print(jnp.argmin(res, axis=1))
print(fx(1))
# vmap_f = vmap(fx, in_axes=(None, 0))
# vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
# a = jnp.array([20,10,30])
# b = jnp.array([6, 5, 4])
# res = vmap_vmap_f(a, b)
# print(res)
# print(jnp.argmin(res, axis=1))

View File

@@ -4,7 +4,7 @@ import numpy as np
from algorithms.neat.function_factory import FunctionFactory
from algorithms.neat.genome.debug.tools import check_array_valid
from utils import Configer
from algorithms.neat.jitable_speciate import jitable_speciate
from algorithms.neat.population import speciate
from algorithms.neat.genome.crossover import crossover
from algorithms.neat.genome.utils import I_INT
from time import time
@@ -23,7 +23,9 @@ if __name__ == '__main__':
spe_center_connections = np.full((species_size, C, 4), np.nan)
spe_center_nodes[0] = pop_nodes[0]
spe_center_connections[0] = pop_connections[0]
spe_keys = np.full((species_size,), I_INT)
spe_keys[0] = 0
new_spe_key = 1
key = jax.random.PRNGKey(0)
new_node_idx = 100
@@ -43,25 +45,31 @@ if __name__ == '__main__':
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
crossover_keys = jax.random.split(subkey, len(pop_nodes))
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
# for i in range(len(pop_nodes)):
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
#speciate next generation
idx2specie, spe_center_nodes, spe_center_cons = jitable_speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections,
compatibility_threshold=2.5)
idx2specie, spe_center_nodes, spe_center_cons, spe_keys, new_spe_key = speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections,
spe_keys, new_spe_key,
compatibility_threshold=3)
idx2specie = np.array(idx2specie)
spe_dict = {}
for i in range(len(idx2specie)):
spe_idx = idx2specie[i]
if spe_idx not in spe_dict:
spe_dict[spe_idx] = 1
else:
spe_dict[spe_idx] += 1
print(spe_keys, new_spe_key)
print(spe_dict)
assert np.all(idx2specie != I_INT)
#
# idx2specie = np.array(idx2specie)
# spe_dict = {}
# for i in range(len(idx2specie)):
# spe_idx = idx2specie[i]
# if spe_idx not in spe_dict:
# spe_dict[spe_idx] = 1
# else:
# spe_dict[spe_idx] += 1
#
# print(spe_dict)
# assert np.all(idx2specie != I_INT)
print(time() - start_time)
# print(idx2specie)

View File

@@ -12,7 +12,7 @@ def main():
config = Configer.load_config()
problem = Xor()
problem.refactor_config(config)
pipeline = Pipeline(config, seed=0)
pipeline = Pipeline(config, seed=1)
pipeline.auto_run(problem.evaluate)