# Files
# tensorneat-mend/aaai_exp.py
# 2023-08-13 12:30:50 +08:00
#
# 132 lines
# 3.8 KiB
# Python

from typing import Callable
from time import time
import jax
from jax import numpy as jnp, vmap, jit
import gymnax
import numpy as np
from config import *
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def conf_cartpole():
    """Assemble the experiment configuration used by this script.

    NOTE(review): despite the "cartpole" in the name, the environment
    configured here is 'Pendulum-v1' -- confirm whether the name is a
    leftover from an earlier experiment.

    Returns:
        A ``Config`` combining the basic run settings, the NEAT topology
        settings (3 inputs, 1 output), a tanh-only gene configuration and
        the gymnax problem description.
    """
    # Global run settings: RNG seed, stopping criteria and population size.
    basic_section = BasicConfig(
        seed=42,
        fitness_target=500,
        generation_limit=150,
        pop_size=10000
    )

    # Network interface: three inputs, one output.
    neat_section = NeatConfig(
        inputs=3,
        outputs=1,
    )

    # tanh is both the default activation and the only option offered.
    gene_section = NormalGeneConfig(
        activation_default=Act.tanh,
        activation_options=(Act.tanh,),
    )

    problem_section = GymNaxConfig(
        env_name='Pendulum-v1',
        # the action of pendulum is [-2, 2], so the raw output is doubled
        output_transform=lambda out: out * 2
    )

    return Config(
        basic=basic_section,
        neat=neat_section,
        gene=gene_section,
        problem=problem_section
    )
def batch_evaluate(
        key,
        alg_state,
        genomes,
        env_params,
        batch_transform: Callable,
        batch_act: Callable,
        batch_reset: Callable,
        batch_step: Callable,
):
    """Roll out the whole population once and time each phase.

    Every individual is stepped in lock-step until all of them have
    reported ``done`` at least once. An individual's fitness is the sum of
    the rewards it received up to and including the step on which it first
    reported ``done``; later rewards are masked out.

    Args:
        key: JAX PRNG key used to derive the reset keys and per-step keys.
        alg_state: Algorithm state, passed through to ``batch_transform``
            and ``batch_act`` unchanged.
        genomes: Population object; ``genomes.nodes.shape[0]`` is taken as
            the population size.
        env_params: Environment parameters forwarded to ``batch_reset``
            and ``batch_step``.
        batch_transform: Called as ``batch_transform(alg_state, genomes)``.
        batch_act: Called as
            ``batch_act(alg_state, observations, genomes_transform)``; its
            result must expose ``block_until_ready()``.
        batch_reset: Called as ``batch_reset(reset_keys, env_params)`` and
            must return ``(observations, states)``.
        batch_step: Called as
            ``batch_step(step_keys, states, actions, env_params)`` and must
            return ``(observations, states, reward, done, info)``.

    Returns:
        ``(fitnesses, alg_time, env_time, forward_time)`` where
        ``fitnesses`` is a NumPy array with one entry per individual and
        the three timings are wall-clock seconds spent in the transform,
        the environment steps and the forward passes respectively.
    """
    env_time, forward_time = 0, 0
    pop_size = genomes.nodes.shape[0]

    alg_tic = time()
    genomes_transform = batch_transform(alg_state, genomes)
    # JAX dispatches work asynchronously: without waiting here the timer
    # would only capture dispatch latency and the real transform cost
    # would be charged to the first forward pass instead.
    jax.block_until_ready(genomes_transform)
    alg_time = time() - alg_tic

    reset_keys = jax.random.split(key, pop_size)
    observations, states = batch_reset(reset_keys, env_params)

    done = np.zeros(pop_size, dtype=bool)
    fitnesses = np.zeros(pop_size)

    while not np.all(done):
        key, _ = jax.random.split(key)
        vmap_keys = jax.random.split(key, pop_size)

        forward_tic = time()
        actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
        forward_time += time() - forward_tic

        env_tic = time()
        observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
        # Pulling the results to the host also waits for the step to finish,
        # so the environment timing covers the actual computation.
        reward, current_done = jax.device_get([reward, current_done])
        env_time += time() - env_tic

        # Mask with the *previous* done flags so the reward of the
        # terminating step itself is still counted.
        fitnesses += reward * np.logical_not(done)
        done = np.logical_or(done, current_done)

    return fitnesses, alg_time, env_time, forward_time
def main():
    """Run the NEAT experiment and print per-generation timing and fitness.

    Builds the configuration, the NEAT algorithm and the gymnax
    environment, jit/vmap-compiles the transform, act, reset and step
    functions, then runs ``conf.basic.generation_limit`` generations. Each
    generation evaluates the current population with ``batch_evaluate``,
    feeds the fitnesses back through ``algorithm.tell`` and prints the
    timing breakdown together with the maximum and mean fitness.

    Note: the loop always runs the full generation limit; the configured
    fitness target is not consulted here.
    """
    conf = conf_cartpole()
    algorithm = NEAT(conf, NormalGene)

    def act(state, inputs, genome):
        # Network output followed by the problem-specific action scaling.
        res = algorithm.act(state, inputs, genome)
        return conf.problem.output_transform(res)

    batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
    # (state, obs, genome_transform) -> action
    batch_act = jit(vmap(act, in_axes=(None, 0, 0)))

    env, env_params = gymnax.make(conf.problem.env_name)
    # (seed, params) -> (ini_obs, ini_state)
    batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
    # (seed, state, action, params) -> (obs, state, reward, done, info)
    batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))

    key = jax.random.PRNGKey(conf.basic.seed)
    alg_key, pro_key = jax.random.split(key)
    alg_state = algorithm.setup(alg_key)

    for i in range(conf.basic.generation_limit):
        total_tic = time()

        # Fresh problem key for this generation's rollouts.
        pro_key, _ = jax.random.split(pro_key)

        fitnesses, a1, env_time, forward_time = batch_evaluate(
            pro_key,
            alg_state,
            algorithm.ask(alg_state),
            env_params,
            batch_transform,
            batch_act,
            batch_reset,
            batch_step
        )

        alg_tic = time()
        alg_state = algorithm.tell(alg_state, fitnesses)
        # Wait for the update to finish so the timing below is meaningful.
        # jax.block_until_ready walks the whole pytree; the former
        # top-level jax.tree_map alias is no longer available in current
        # JAX releases.
        alg_state = jax.block_until_ready(alg_state)
        a2 = time() - alg_tic

        # Algorithm time = genome transform (a1) + population update (a2).
        alg_time = a1 + a2
        total_time = time() - total_tic

        print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
              f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
# Script entry point: run the experiment only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()