from typing import Callable
from time import time

import jax
from jax import numpy as jnp, vmap, jit
import gymnax
import numpy as np

from config import *
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv


def conf_cartpole():
    """Build the NEAT configuration used to evolve a controller for gymnax CartPole-v1."""
    return Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=500,
            generation_limit=150,
            pop_size=10000
        ),
        neat=NeatConfig(
            inputs=4,
            outputs=2,
        ),
        gene=NormalGeneConfig(
            activation_default=Act.tanh,
            activation_options=(Act.tanh,),
        ),
        problem=GymNaxConfig(
            env_name='CartPole-v1',
            output_transform=lambda out: jnp.argmax(out)  # the action of cartpole is {0, 1}
        )
    )


def _block_until_ready(tree):
    """Wait until every JAX array leaf of ``tree`` is computed, then return ``tree``.

    JAX dispatches work asynchronously, so a wall-clock measurement only reflects
    real compute time once the results have been materialized. Leaves without a
    ``block_until_ready`` method (e.g. plain Python scalars) are passed through
    unchanged instead of raising.
    """
    return jax.tree_util.tree_map(
        lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else x,
        tree,
    )


def batch_evaluate(
        key,
        alg_state,
        genomes,
        env_params,
        batch_transform: Callable,
        batch_act: Callable,
        batch_reset: Callable,
        batch_step: Callable,
):
    """Roll out one episode per genome in lock-step and accumulate episode rewards.

    Args:
        key: JAX PRNG key. It is only ever split, never consumed directly.
        alg_state: algorithm state forwarded to ``batch_transform`` / ``batch_act``.
        genomes: population; ``genomes.nodes.shape[0]`` is taken as the population size.
        env_params: environment parameters shared by every rollout.
        batch_transform: ``(alg_state, genomes) -> genomes_transform``.
        batch_act: ``(alg_state, observations, genomes_transform) -> actions``.
        batch_reset: ``(keys, env_params) -> (observations, states)``.
        batch_step: ``(keys, states, actions, env_params) -> (obs, states, reward, done, info)``.

    Returns:
        ``(fitnesses, alg_time, env_time, forward_time)``. ``fitnesses`` is a numpy
        array of summed rewards per genome. The three times are wall-clock seconds.
    """
    alg_time, env_time, forward_time = 0, 0, 0
    pop_size = genomes.nodes.shape[0]

    alg_tic = time()
    # Block so the transform's cost is charged to alg_time rather than leaking into
    # the first forward_time measurement through asynchronous dispatch.
    genomes_transform = _block_until_ready(batch_transform(alg_state, genomes))
    alg_time += time() - alg_tic

    # Use a dedicated subkey for the resets so `key` is never consumed twice.
    key, reset_key = jax.random.split(key)
    reset_keys = jax.random.split(reset_key, pop_size)
    observations, states = batch_reset(reset_keys, env_params)

    done = np.zeros(pop_size, dtype=bool)
    fitnesses = np.zeros(pop_size)

    while not np.all(done):
        key, step_key = jax.random.split(key)
        vmap_keys = jax.random.split(step_key, pop_size)

        forward_tic = time()
        actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
        forward_time += time() - forward_tic

        env_tic = time()
        observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
        # device_get forces the step computation to finish, so env_time is accurate.
        reward, current_done = jax.device_get([reward, current_done])
        env_time += time() - env_tic

        # Only rewards collected before a genome's first `done` count toward its fitness.
        fitnesses += reward * np.logical_not(done)
        done = np.logical_or(done, current_done)

    return fitnesses, alg_time, env_time, forward_time


def main():
    """Evolve a NEAT population on CartPole-v1 and print per-generation timings."""
    conf = conf_cartpole()
    algorithm = NEAT(conf, NormalGene)

    def act(state, inputs, genome):
        res = algorithm.act(state, inputs, genome)
        return conf.problem.output_transform(res)

    # (state, genomes) -> genomes_transform
    batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
    # (state, obs, genome_transform) -> action
    batch_act = jit(vmap(act, in_axes=(None, 0, 0)))

    env, env_params = gymnax.make(conf.problem.env_name)
    # (seed, params) -> (ini_obs, ini_state)
    batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
    # (seed, state, action, params) -> (obs, state, reward, done, info)
    batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))

    key = jax.random.PRNGKey(conf.basic.seed)
    alg_key, pro_key = jax.random.split(key)
    alg_state = algorithm.setup(alg_key)

    for i in range(conf.basic.generation_limit):
        # Hand a fresh subkey to the evaluation; pro_key itself is only ever split,
        # so no key is reused across generations.
        pro_key, eval_key = jax.random.split(pro_key)

        ask_tic = time()
        genomes = _block_until_ready(algorithm.ask(alg_state))
        ask_time = time() - ask_tic

        fitnesses, eval_alg_time, env_time, forward_time = batch_evaluate(
            eval_key,
            alg_state,
            genomes,
            env_params,
            batch_transform,
            batch_act,
            batch_reset,
            batch_step,
        )

        tell_tic = time()
        alg_state = _block_until_ready(algorithm.tell(alg_state, fitnesses))
        tell_time = time() - tell_tic

        alg_time = ask_time + eval_alg_time + tell_time
        print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, "
              f"max_fitness: {np.max(fitnesses):.2f}")


if __name__ == '__main__':
    main()