diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index c4a5f48..feaf47f 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -92,10 +92,10 @@ class SpeciesController: # First, fast match the population to previous species if previous_species_list: # exist previous species rid_list = [new_representatives[sid] for sid in previous_species_list] - res_pop_distance = [ - jax.device_get(o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)) + res_pop_distance = jax.device_get([ + o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) for rid in rid_list - ] + ]) pop_res_distance = np.stack(res_pop_distance, axis=0).T for i in range(pop_res_distance.shape[0]): @@ -118,10 +118,10 @@ class SpeciesController: if len(new_representatives) != 0: # the representatives of new species sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) - distances = [ - jax.device_get(o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])) + distances = jax.device_get([ + o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) for r in rid - ] + ]) distances = np.array(distances) min_idx = np.argmin(distances) min_val = distances[min_idx] diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 2e4487f..5748071 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -8,56 +8,14 @@ from functools import partial from examples.time_utils import using_cprofile -def func(x, y): - """ - :param x: (100, ) - :param y: (100, - :return: - """ - return x * y - - -def func2(x, y, s): - """ - :param x: (100, ) - :param y: (100, - :return: - """ - if s == '123': - return 0 - else: - return x * y - - @jit -def func3(x, y): - return func2(x, y, '123') +def func(x, y): + return x + y -# @using_cprofile -def main(): - key = jax.random.PRNGKey(42) +a, b, c = jnp.array([1]), jnp.array([2]), jnp.array([3]) +li = [a, b, c] - x1, y1 = jax.random.normal(key, shape=(1000,)), jax.random.normal(key, shape=(1000,)) +cpu_li = jax.device_get(li) - jit_lower_func = jit(func).lower(1, 2).compile() - print(type(jit_lower_func)) - print(jit_lower_func.memory_analysis()) - - jit_compiled_func2 = jit(func2, static_argnames=['s']).lower(x1, y1, '123').compile() - print(jit_compiled_func2(x1, y1)) - - # print(jit_compiled_func2(x1, y1)) - - f = func3.lower(x1, y1).compile() - - print(f(x1, y1)) - - # print(jit_lower_func(x1, y1)) - - # x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,)) - # print(jit_lower_func(x2, y2)) - - -if __name__ == '__main__': - main() +print(cpu_li) \ No newline at end of file