# 132 lines, 3.8 KiB, Python
from typing import Callable
|
|
from time import time
|
|
|
|
import jax
|
|
from jax import numpy as jnp, vmap, jit
|
|
import gymnax
|
|
import numpy as np
|
|
|
|
from config import *
|
|
from algorithm import NEAT
|
|
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
|
|
|
|
|
def conf_cartpole():
    """Assemble the experiment configuration consumed by ``main``.

    Despite the function name, the problem section selects the gymnax
    ``Pendulum-v1`` environment. The network is declared with 3 inputs and
    1 output, ``tanh`` is the only activation offered, and
    ``output_transform`` doubles whatever value it is given.

    NOTE(review): ``fitness_target=500`` looks like it was carried over from
    a CartPole setup -- confirm it is intended for Pendulum.

    Returns:
        The populated ``Config`` object.
    """
    basic_section = BasicConfig(
        seed=42,
        fitness_target=500,
        generation_limit=150,
        pop_size=10000,
    )
    neat_section = NeatConfig(inputs=3, outputs=1)
    gene_section = NormalGeneConfig(
        activation_default=Act.tanh,
        activation_options=(Act.tanh,),
    )
    problem_section = GymNaxConfig(
        env_name='Pendulum-v1',
        # the action of pendulum is [-2, 2]
        output_transform=lambda out: out * 2,
    )
    return Config(
        basic=basic_section,
        neat=neat_section,
        gene=gene_section,
        problem=problem_section,
    )
|
|
|
|
|
|
def batch_evaluate(
        key,
        alg_state,
        genomes,
        env_params,
        batch_transform: Callable,
        batch_act: Callable,
        batch_reset: Callable,
        batch_step: Callable,
):
    """Roll out every genome in its own environment and accumulate rewards.

    Each genome is stepped until its environment has reported ``done`` once;
    rewards received after a genome's first ``done`` are ignored. The loop
    ends when every genome has finished.

    Args:
        key: JAX PRNG key used to derive the reset and per-step keys.
        alg_state: Algorithm state forwarded to ``batch_transform`` and
            ``batch_act``.
        genomes: Population container; ``genomes.nodes.shape[0]`` is taken
            as the population size.
        env_params: Environment parameters forwarded to ``batch_reset`` and
            ``batch_step``.
        batch_transform: ``(alg_state, genomes) -> transformed genomes``.
        batch_act: ``(alg_state, observations, transformed) -> actions``;
            the result must expose ``block_until_ready()``.
        batch_reset: ``(keys, env_params) -> (observations, states)``.
        batch_step: ``(keys, states, actions, env_params) ->
            (observations, states, reward, done, info)``.

    Returns:
        A tuple ``(fitnesses, alg_time, env_time, forward_time)`` where
        ``fitnesses`` is a NumPy array with one accumulated reward per genome
        and the remaining entries are wall-clock durations in seconds.
    """
    alg_time, env_time, forward_time = 0, 0, 0
    pop_size = genomes.nodes.shape[0]

    alg_tic = time()
    # JAX dispatches asynchronously: wait for every leaf of the transform
    # result so the cost is charged to alg_time rather than leaking into the
    # first forward_time sample.
    genomes_transform = jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready(),
        batch_transform(alg_state, genomes),
    )
    alg_time += time() - alg_tic

    # Never consume the same key twice: carve off a dedicated subkey for the
    # resets and keep the carried key for later splits.
    key, reset_key = jax.random.split(key)
    reset_keys = jax.random.split(reset_key, pop_size)
    observations, states = batch_reset(reset_keys, env_params)

    done = np.zeros(pop_size, dtype=bool)
    fitnesses = np.zeros(pop_size)

    while not np.all(done):
        # Fresh subkey per step; the carried key is only ever split.
        key, step_key = jax.random.split(key)
        step_keys = jax.random.split(step_key, pop_size)

        forward_tic = time()
        actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
        forward_time += time() - forward_tic

        env_tic = time()
        observations, states, reward, current_done, _ = batch_step(step_keys, states, actions, env_params)
        # Pulling reward/done to the host also forces the step to finish.
        reward, current_done = jax.device_get([reward, current_done])
        env_time += time() - env_tic

        # Only genomes that were still running before this step earn reward.
        fitnesses += reward * np.logical_not(done)
        done = np.logical_or(done, current_done)

    return fitnesses, alg_time, env_time, forward_time
|
|
|
|
|
|
def main():
    """Run the NEAT loop and print per-generation timings and fitness.

    Builds the configuration and algorithm, wraps the algorithm and the
    gymnax environment in ``jit(vmap(...))`` batched callables, then for
    ``conf.basic.generation_limit`` generations evaluates the population
    with ``batch_evaluate``, feeds the fitnesses back through
    ``algorithm.tell`` and prints one summary line.
    """
    conf = conf_cartpole()
    algorithm = NEAT(conf, NormalGene)

    def act(state, inputs, genome):
        # Apply the problem-specific output transform to the raw network output.
        res = algorithm.act(state, inputs, genome)
        return conf.problem.output_transform(res)

    batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
    # (state, obs, genome_transform) -> action
    batch_act = jit(vmap(act, in_axes=(None, 0, 0)))

    env, env_params = gymnax.make(conf.problem.env_name)
    # (seed, params) -> (ini_obs, ini_state)
    batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
    # (seed, state, action, params) -> (obs, state, reward, done, info)
    batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))

    key = jax.random.PRNGKey(conf.basic.seed)
    alg_key, pro_key = jax.random.split(key)
    alg_state = algorithm.setup(alg_key)

    for i in range(conf.basic.generation_limit):
        total_tic = time()

        pro_key, _ = jax.random.split(pro_key)

        fitnesses, a1, env_time, forward_time = batch_evaluate(
            pro_key,
            alg_state,
            algorithm.ask(alg_state),
            env_params,
            batch_transform,
            batch_act,
            batch_reset,
            batch_step,
        )

        alg_tic = time()
        alg_state = algorithm.tell(alg_state, fitnesses)
        # Wait for the update to finish so the timing below is meaningful.
        # jax.tree_util.tree_map is used because the top-level jax.tree_map
        # alias is deprecated and absent from recent JAX releases.
        alg_state = jax.tree_util.tree_map(lambda x: x.block_until_ready(), alg_state)
        a2 = time() - alg_tic

        # Algorithm time = genome transform (inside evaluation) + tell().
        alg_time = a1 + a2
        total_time = time() - total_tic

        print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
              f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
|
|
|
|
|
|
# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
|