"""Timing experiment: NEAT on gymnax CartPole-v1 with a per-generation cost breakdown.

Provenance: this module is the ``aaai_exp.py`` file added by commit
33e6ef4916baf691e3ac804b53b6515c97cea4c9 ("add test for aaai", wls2002,
Fri, 11 Aug 2023 19:18:11 +0800).  The same commit also edits
``examples/gymnax/cartpole.py``: inside ``example_conf3`` the ``BasicConfig``
argument ``fitness_target`` changes from 500 to 501.  That one-line change
belongs to the other file; it is only recorded here, not applied.
"""
from typing import Callable
from time import time

import jax
from jax import numpy as jnp, vmap, jit
import gymnax
import numpy as np

from config import *
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv


def conf_cartpole():
    """Build the configuration used by this experiment.

    Returns:
        A ``Config`` for the ``'CartPole-v1'`` environment: seed 42, 10000
        genomes, at most 150 generations, 4 network inputs, 2 network outputs
        and tanh as the only activation option.
    """
    return Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=500,
            generation_limit=150,
            pop_size=10000,
        ),
        neat=NeatConfig(
            inputs=4,
            outputs=2,
        ),
        gene=NormalGeneConfig(
            activation_default=Act.tanh,
            activation_options=(Act.tanh,),
        ),
        problem=GymNaxConfig(
            env_name='CartPole-v1',
            output_transform=lambda out: jnp.argmax(out)  # the action of cartpole is {0, 1}
        ),
    )


def batch_evaluate(
        key,
        alg_state,
        genomes,
        env_params,
        batch_transform: Callable,
        batch_act: Callable,
        batch_reset: Callable,
        batch_step: Callable,
):
    """Run one episode per genome and time the individual phases.

    Args:
        key: JAX PRNG key from which the reset keys and all step keys are derived.
        alg_state: algorithm state, forwarded unchanged to ``batch_transform``
            and ``batch_act``.
        genomes: the population; ``genomes.nodes.shape[0]`` is used as the
            population size.
        env_params: forwarded unchanged to ``batch_reset`` and ``batch_step``.
        batch_transform: called as ``(alg_state, genomes)``; its result is
            passed to ``batch_act``.
        batch_act: called as ``(alg_state, observations, transformed_genomes)``
            and returns the actions.
        batch_reset: called as ``(keys, env_params)`` and returns
            ``(observations, states)``.
        batch_step: called as ``(keys, states, actions, env_params)`` and
            returns ``(observations, states, reward, done, info)``.

    Returns:
        ``(fitnesses, alg_time, env_time, forward_time)``.  ``fitnesses`` is a
        NumPy array with one entry per genome holding the rewards summed up to
        and including the step on which that environment first reported done.
        The three times are wall-clock seconds spent in ``batch_transform``,
        in ``batch_step`` (including the device-to-host copy of reward/done),
        and in ``batch_act`` respectively.  The reset call is not timed.
    """
    pop_size = genomes.nodes.shape[0]
    alg_time = env_time = forward_time = 0.0

    # JAX dispatches work asynchronously, so wait for the result before
    # stopping the clock; otherwise the cost of the transform would be
    # charged to whichever later call happens to block first.
    tic = time()
    genomes_transform = jax.block_until_ready(batch_transform(alg_state, genomes))
    alg_time += time() - tic

    # Consume a fresh subkey for every use instead of re-using the carried key.
    key, reset_key = jax.random.split(key)
    observations, states = batch_reset(jax.random.split(reset_key, pop_size), env_params)

    done = np.zeros(pop_size, dtype=bool)
    fitnesses = np.zeros(pop_size)

    # Step all environments in lock-step until every one has finished once.
    while not np.all(done):
        key, step_key = jax.random.split(key)
        step_keys = jax.random.split(step_key, pop_size)

        tic = time()
        actions = jax.block_until_ready(batch_act(alg_state, observations, genomes_transform))
        forward_time += time() - tic

        tic = time()
        observations, states, reward, current_done, _ = batch_step(step_keys, states, actions, env_params)
        # Pulling the values to the host also forces the step to complete.
        reward, current_done = jax.device_get([reward, current_done])
        env_time += time() - tic

        # Only environments that were still running before this step score.
        fitnesses += reward * np.logical_not(done)
        done = np.logical_or(done, current_done)

    return fitnesses, alg_time, env_time, forward_time


def main():
    """Evolve CartPole controllers and print a timing line per generation.

    NOTE(review): ``fitness_target`` is never consulted here, so the loop
    always runs ``generation_limit`` generations -- presumably intentional for
    a timing experiment; confirm.
    NOTE(review): the first generation's timings presumably include JIT
    compilation of the batched functions -- confirm before comparing numbers.
    """
    conf = conf_cartpole()
    algorithm = NEAT(conf, NormalGene)

    def act(state, inputs, genome):
        # Network forward pass followed by the problem's output transform.
        res = algorithm.act(state, inputs, genome)
        return conf.problem.output_transform(res)

    # (state, genomes) -> transformed genomes, mapped over the genome axis
    batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
    # (state, obs, genome_transform) -> action
    batch_act = jit(vmap(act, in_axes=(None, 0, 0)))

    env, env_params = gymnax.make(conf.problem.env_name)
    # (seed, params) -> (ini_obs, ini_state)
    batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
    # (seed, state, action, params) -> (obs, state, reward, done, info)
    batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))

    key = jax.random.PRNGKey(conf.basic.seed)
    alg_key, pro_key = jax.random.split(key)
    alg_state = algorithm.setup(alg_key)

    for i in range(conf.basic.generation_limit):
        # Hand a fresh subkey to the evaluation; keep the other half as carry.
        pro_key, eval_key = jax.random.split(pro_key)

        fitnesses, a1, env_time, forward_time = batch_evaluate(
            eval_key,
            alg_state,
            algorithm.ask(alg_state),
            env_params,
            batch_transform,
            batch_act,
            batch_reset,
            batch_step,
        )

        tic = time()
        # block_until_ready walks the whole state pytree, so the measured time
        # covers the complete update rather than just its dispatch.
        alg_state = jax.block_until_ready(algorithm.tell(alg_state, fitnesses))
        a2 = time() - tic

        alg_time = a1 + a2

        print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, "
              f"max_fitness: {np.max(fitnesses):.2f}")


if __name__ == '__main__':
    main()