add test for aaai
This commit is contained in:
127
aaai_exp.py
Normal file
127
aaai_exp.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
from typing import Callable
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import numpy as jnp, vmap, jit
|
||||||
|
import gymnax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config import *
|
||||||
|
from algorithm import NEAT
|
||||||
|
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||||
|
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||||
|
|
||||||
|
|
||||||
|
def conf_cartpole():
|
||||||
|
return Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
seed=42,
|
||||||
|
fitness_target=500,
|
||||||
|
generation_limit=150,
|
||||||
|
pop_size=10000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
inputs=4,
|
||||||
|
outputs=2,
|
||||||
|
),
|
||||||
|
gene=NormalGeneConfig(
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
),
|
||||||
|
problem=GymNaxConfig(
|
||||||
|
env_name='CartPole-v1',
|
||||||
|
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_evaluate(
|
||||||
|
key,
|
||||||
|
alg_state,
|
||||||
|
genomes,
|
||||||
|
env_params,
|
||||||
|
batch_transform: Callable,
|
||||||
|
batch_act: Callable,
|
||||||
|
batch_reset: Callable,
|
||||||
|
batch_step: Callable,
|
||||||
|
):
|
||||||
|
alg_time, env_time, forward_time = 0, 0, 0
|
||||||
|
pop_size = genomes.nodes.shape[0]
|
||||||
|
|
||||||
|
alg_tic = time()
|
||||||
|
genomes_transform = batch_transform(alg_state, genomes)
|
||||||
|
alg_time += time() - alg_tic
|
||||||
|
|
||||||
|
reset_keys = jax.random.split(key, pop_size)
|
||||||
|
observations, states = batch_reset(reset_keys, env_params)
|
||||||
|
|
||||||
|
done = np.zeros(pop_size, dtype=bool)
|
||||||
|
fitnesses = np.zeros(pop_size)
|
||||||
|
|
||||||
|
while not np.all(done):
|
||||||
|
key, _ = jax.random.split(key)
|
||||||
|
vmap_keys = jax.random.split(key, pop_size)
|
||||||
|
|
||||||
|
forward_tic = time()
|
||||||
|
actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
|
||||||
|
forward_time += time() - forward_tic
|
||||||
|
|
||||||
|
env_tic = time()
|
||||||
|
observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
|
||||||
|
reward, current_done = jax.device_get([reward, current_done])
|
||||||
|
env_time += time() - env_tic
|
||||||
|
|
||||||
|
fitnesses += reward * np.logical_not(done)
|
||||||
|
done = np.logical_or(done, current_done)
|
||||||
|
|
||||||
|
return fitnesses, alg_time, env_time, forward_time
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
conf = conf_cartpole()
|
||||||
|
algorithm = NEAT(conf, NormalGene)
|
||||||
|
|
||||||
|
def act(state, inputs, genome):
|
||||||
|
res = algorithm.act(state, inputs, genome)
|
||||||
|
return conf.problem.output_transform(res)
|
||||||
|
|
||||||
|
batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
|
||||||
|
# (state, obs, genome_transform) -> action
|
||||||
|
batch_act = jit(vmap(act, in_axes=(None, 0, 0)))
|
||||||
|
|
||||||
|
env, env_params = gymnax.make(conf.problem.env_name)
|
||||||
|
# (seed, params) -> (ini_obs, ini_state)
|
||||||
|
batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
|
||||||
|
# (seed, state, action, params) -> (obs, state, reward, done, info)
|
||||||
|
batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))
|
||||||
|
|
||||||
|
key = jax.random.PRNGKey(conf.basic.seed)
|
||||||
|
alg_key, pro_key = jax.random.split(key)
|
||||||
|
alg_state = algorithm.setup(alg_key)
|
||||||
|
|
||||||
|
for i in range(conf.basic.generation_limit):
|
||||||
|
pro_key, _ = jax.random.split(pro_key)
|
||||||
|
|
||||||
|
fitnesses, a1, env_time, forward_time = batch_evaluate(
|
||||||
|
pro_key,
|
||||||
|
alg_state,
|
||||||
|
algorithm.ask(alg_state),
|
||||||
|
env_params,
|
||||||
|
batch_transform,
|
||||||
|
batch_act,
|
||||||
|
batch_reset,
|
||||||
|
batch_step
|
||||||
|
)
|
||||||
|
alg_tic = time()
|
||||||
|
alg_state = algorithm.tell(alg_state, fitnesses)
|
||||||
|
alg_state = jax.tree_map(lambda x: x.block_until_ready(), alg_state)
|
||||||
|
a2 = time() - alg_tic
|
||||||
|
|
||||||
|
alg_time = a1 + a2
|
||||||
|
|
||||||
|
print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, "
|
||||||
|
f"max_fitness: {np.max(fitnesses):.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -55,7 +55,7 @@ def example_conf3():
|
|||||||
return Config(
|
return Config(
|
||||||
basic=BasicConfig(
|
basic=BasicConfig(
|
||||||
seed=42,
|
seed=42,
|
||||||
fitness_target=500,
|
fitness_target=501,
|
||||||
pop_size=10000
|
pop_size=10000
|
||||||
),
|
),
|
||||||
neat=NeatConfig(
|
neat=NeatConfig(
|
||||||
|
|||||||
Reference in New Issue
Block a user