From 9ae9d9dfdce4329c9534fd064351d1b7e11b4552 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 13 Aug 2023 12:30:50 +0800 Subject: [PATCH] add aaai exp --- aaai_exp.py | 16 ++++++++++------ examples/a.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 examples/a.py diff --git a/aaai_exp.py b/aaai_exp.py index 9913716..915da53 100644 --- a/aaai_exp.py +++ b/aaai_exp.py @@ -21,16 +21,16 @@ def conf_cartpole(): pop_size=10000 ), neat=NeatConfig( - inputs=4, - outputs=2, + inputs=3, + outputs=1, ), gene=NormalGeneConfig( activation_default=Act.tanh, activation_options=(Act.tanh,), ), problem=GymNaxConfig( - env_name='CartPole-v1', - output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} + env_name='Pendulum-v1', + output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2] ) ) @@ -100,6 +100,9 @@ def main(): alg_state = algorithm.setup(alg_key) for i in range(conf.basic.generation_limit): + + total_tic = time() + pro_key, _ = jax.random.split(pro_key) fitnesses, a1, env_time, forward_time = batch_evaluate( @@ -118,9 +121,10 @@ def main(): a2 = time() - alg_tic alg_time = a1 + a2 + total_time = time() - total_tic - print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, " - f"max_fitness: {np.max(fitnesses):.2f}") + print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, " + f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}") if __name__ == '__main__': diff --git a/examples/a.py b/examples/a.py new file mode 100644 index 0000000..d100c6e --- /dev/null +++ b/examples/a.py @@ -0,0 +1,33 @@ +import jax.random +import numpy as np +import jax.numpy as jnp +import time + + +def random_array(key): + return jax.random.normal(key, (1000,)) + +def random_array_np(): + return np.random.normal(size=(1000,)) + + +def t_jax(): + key = jax.random.PRNGKey(42) + max_li = [] + tic = time.time() + for _ in range(100): + key, sub_key = jax.random.split(key) + array = random_array(sub_key) + array = jax.device_get(array) + max_li.append(max(array)) + print(max_li, time.time() - tic) + +def t_np(): + max_li = [] + tic = time.time() + for _ in range(100): + max_li.append(max(random_array_np())) + print(max_li, time.time() - tic) + +if __name__ == '__main__': + t_np() \ No newline at end of file