add test for aaai
This commit is contained in:
127
aaai_exp.py
Normal file
127
aaai_exp.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import Callable
|
||||
from time import time
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap, jit
|
||||
import gymnax
|
||||
import numpy as np
|
||||
|
||||
from config import *
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def conf_cartpole():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=500,
|
||||
generation_limit=150,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=4,
|
||||
outputs=2,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def batch_evaluate(
|
||||
key,
|
||||
alg_state,
|
||||
genomes,
|
||||
env_params,
|
||||
batch_transform: Callable,
|
||||
batch_act: Callable,
|
||||
batch_reset: Callable,
|
||||
batch_step: Callable,
|
||||
):
|
||||
alg_time, env_time, forward_time = 0, 0, 0
|
||||
pop_size = genomes.nodes.shape[0]
|
||||
|
||||
alg_tic = time()
|
||||
genomes_transform = batch_transform(alg_state, genomes)
|
||||
alg_time += time() - alg_tic
|
||||
|
||||
reset_keys = jax.random.split(key, pop_size)
|
||||
observations, states = batch_reset(reset_keys, env_params)
|
||||
|
||||
done = np.zeros(pop_size, dtype=bool)
|
||||
fitnesses = np.zeros(pop_size)
|
||||
|
||||
while not np.all(done):
|
||||
key, _ = jax.random.split(key)
|
||||
vmap_keys = jax.random.split(key, pop_size)
|
||||
|
||||
forward_tic = time()
|
||||
actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
|
||||
forward_time += time() - forward_tic
|
||||
|
||||
env_tic = time()
|
||||
observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
|
||||
reward, current_done = jax.device_get([reward, current_done])
|
||||
env_time += time() - env_tic
|
||||
|
||||
fitnesses += reward * np.logical_not(done)
|
||||
done = np.logical_or(done, current_done)
|
||||
|
||||
return fitnesses, alg_time, env_time, forward_time
|
||||
|
||||
|
||||
def main():
|
||||
conf = conf_cartpole()
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
|
||||
def act(state, inputs, genome):
|
||||
res = algorithm.act(state, inputs, genome)
|
||||
return conf.problem.output_transform(res)
|
||||
|
||||
batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
|
||||
# (state, obs, genome_transform) -> action
|
||||
batch_act = jit(vmap(act, in_axes=(None, 0, 0)))
|
||||
|
||||
env, env_params = gymnax.make(conf.problem.env_name)
|
||||
# (seed, params) -> (ini_obs, ini_state)
|
||||
batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
|
||||
# (seed, state, action, params) -> (obs, state, reward, done, info)
|
||||
batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))
|
||||
|
||||
key = jax.random.PRNGKey(conf.basic.seed)
|
||||
alg_key, pro_key = jax.random.split(key)
|
||||
alg_state = algorithm.setup(alg_key)
|
||||
|
||||
for i in range(conf.basic.generation_limit):
|
||||
pro_key, _ = jax.random.split(pro_key)
|
||||
|
||||
fitnesses, a1, env_time, forward_time = batch_evaluate(
|
||||
pro_key,
|
||||
alg_state,
|
||||
algorithm.ask(alg_state),
|
||||
env_params,
|
||||
batch_transform,
|
||||
batch_act,
|
||||
batch_reset,
|
||||
batch_step
|
||||
)
|
||||
alg_tic = time()
|
||||
alg_state = algorithm.tell(alg_state, fitnesses)
|
||||
alg_state = jax.tree_map(lambda x: x.block_until_ready(), alg_state)
|
||||
a2 = time() - alg_tic
|
||||
|
||||
alg_time = a1 + a2
|
||||
|
||||
print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, "
|
||||
f"max_fitness: {np.max(fitnesses):.2f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -55,7 +55,7 @@ def example_conf3():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=500,
|
||||
fitness_target=501,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
|
||||
Reference in New Issue
Block a user