add test for aaai

This commit is contained in:
wls2002
2023-08-11 19:18:11 +08:00
parent a778921892
commit 33e6ef4916
2 changed files with 128 additions and 1 deletions

127
aaai_exp.py Normal file
View File

@@ -0,0 +1,127 @@
from typing import Callable
from time import time
import jax
from jax import numpy as jnp, vmap, jit
import gymnax
import numpy as np
from config import *
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def conf_cartpole():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
generation_limit=150,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=2,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
)
)
def batch_evaluate(
key,
alg_state,
genomes,
env_params,
batch_transform: Callable,
batch_act: Callable,
batch_reset: Callable,
batch_step: Callable,
):
alg_time, env_time, forward_time = 0, 0, 0
pop_size = genomes.nodes.shape[0]
alg_tic = time()
genomes_transform = batch_transform(alg_state, genomes)
alg_time += time() - alg_tic
reset_keys = jax.random.split(key, pop_size)
observations, states = batch_reset(reset_keys, env_params)
done = np.zeros(pop_size, dtype=bool)
fitnesses = np.zeros(pop_size)
while not np.all(done):
key, _ = jax.random.split(key)
vmap_keys = jax.random.split(key, pop_size)
forward_tic = time()
actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
forward_time += time() - forward_tic
env_tic = time()
observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
reward, current_done = jax.device_get([reward, current_done])
env_time += time() - env_tic
fitnesses += reward * np.logical_not(done)
done = np.logical_or(done, current_done)
return fitnesses, alg_time, env_time, forward_time
def main():
conf = conf_cartpole()
algorithm = NEAT(conf, NormalGene)
def act(state, inputs, genome):
res = algorithm.act(state, inputs, genome)
return conf.problem.output_transform(res)
batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
# (state, obs, genome_transform) -> action
batch_act = jit(vmap(act, in_axes=(None, 0, 0)))
env, env_params = gymnax.make(conf.problem.env_name)
# (seed, params) -> (ini_obs, ini_state)
batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
# (seed, state, action, params) -> (obs, state, reward, done, info)
batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))
key = jax.random.PRNGKey(conf.basic.seed)
alg_key, pro_key = jax.random.split(key)
alg_state = algorithm.setup(alg_key)
for i in range(conf.basic.generation_limit):
pro_key, _ = jax.random.split(pro_key)
fitnesses, a1, env_time, forward_time = batch_evaluate(
pro_key,
alg_state,
algorithm.ask(alg_state),
env_params,
batch_transform,
batch_act,
batch_reset,
batch_step
)
alg_tic = time()
alg_state = algorithm.tell(alg_state, fitnesses)
alg_state = jax.tree_map(lambda x: x.block_until_ready(), alg_state)
a2 = time() - alg_tic
alg_time = a1 + a2
print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, "
f"max_fitness: {np.max(fitnesses):.2f}")
if __name__ == '__main__':
main()

View File

@@ -55,7 +55,7 @@ def example_conf3():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
fitness_target=501,
pop_size=10000
),
neat=NeatConfig(