Add more RL tasks to the examples

This commit is contained in:
wls2002
2023-08-09 18:01:21 +08:00
parent af54db3b12
commit 3b6fe7eadc
18 changed files with 431 additions and 12 deletions

View File

@@ -6,7 +6,6 @@ import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome, Gene
from utils import Act, Agg
from .substrate import analysis_substrate
from algorithm import NEAT

View File

@@ -6,6 +6,8 @@ from .state import State
class Problem:
jitable: bool
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
self.config = problem_config

View File

@@ -1,5 +1,5 @@
from config import *
from pipeline import Pipeline
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig

View File

@@ -1,5 +1,5 @@
from config import *
from pipeline import Pipeline
from pipeline_jitable_env import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig

View File

@@ -1,5 +1,5 @@
from config import *
from pipeline import Pipeline
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from problem.func_fit import XOR3d, FuncFitConfig

View File

@@ -0,0 +1,39 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
    """Assemble the ``Config`` used by the Acrobot-v1 example."""
    basic = BasicConfig(seed=42, fitness_target=0, pop_size=10000)
    neat = NeatConfig(inputs=6, outputs=3)
    gene = NormalGeneConfig(
        activation_default=Act.tanh,
        activation_options=(Act.tanh,),
    )
    # Acrobot takes a discrete action in {0, 1, 2}: use the index of the
    # largest of the three network outputs.
    problem = GymNaxConfig(
        env_name='Acrobot-v1',
        output_transform=lambda out: jnp.argmax(out),
    )
    return Config(basic=basic, neat=neat, gene=gene, problem=problem)
if __name__ == '__main__':
    # Wire the configuration, the NEAT algorithm and the gymnax problem together.
    conf = example_conf()
    algorithm = NEAT(conf, NormalGene)
    pipeline = Pipeline(conf, algorithm, GymNaxEnv)

    # Initialise the state, compile the step function up front, then evolve.
    state = pipeline.setup()
    pipeline.pre_compile(state)
    state, best = pipeline.auto_run(state)

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -0,0 +1,39 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
    """Assemble the ``Config`` used by the MountainCar-v0 example."""
    basic = BasicConfig(seed=42, fitness_target=0, pop_size=10000)
    neat = NeatConfig(inputs=2, outputs=3)
    gene = NormalGeneConfig(
        activation_default=Act.sigmoid,
        activation_options=(Act.sigmoid,),
    )
    # MountainCar takes a discrete action in {0, 1, 2}: use the index of the
    # largest of the three network outputs.
    problem = GymNaxConfig(
        env_name='MountainCar-v0',
        output_transform=lambda out: jnp.argmax(out),
    )
    return Config(basic=basic, neat=neat, gene=gene, problem=problem)
if __name__ == '__main__':
    # Wire the configuration, the NEAT algorithm and the gymnax problem together.
    conf = example_conf()
    algorithm = NEAT(conf, NormalGene)
    pipeline = Pipeline(conf, algorithm, GymNaxEnv)

    # Initialise the state, compile the step function up front, then evolve.
    state = pipeline.setup()
    pipeline.pre_compile(state)
    state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,38 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
    """Assemble the ``Config`` used by the MountainCarContinuous-v0 example."""
    basic = BasicConfig(seed=42, fitness_target=100, pop_size=10000)
    neat = NeatConfig(inputs=2, outputs=1)
    gene = NormalGeneConfig(
        activation_default=Act.tanh,
        activation_options=(Act.tanh,),
    )
    # No output_transform is given, so the problem config's default is used.
    problem = GymNaxConfig(env_name='MountainCarContinuous-v0')
    return Config(basic=basic, neat=neat, gene=gene, problem=problem)
if __name__ == '__main__':
    # Wire the configuration, the NEAT algorithm and the gymnax problem together.
    conf = example_conf()
    algorithm = NEAT(conf, NormalGene)
    pipeline = Pipeline(conf, algorithm, GymNaxEnv)

    # Initialise the state, compile the step function up front, then evolve.
    state = pipeline.setup()
    pipeline.pre_compile(state)
    state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,39 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
    """Assemble the ``Config`` used by the Pendulum-v1 example."""
    basic = BasicConfig(seed=42, fitness_target=0, pop_size=10000)
    neat = NeatConfig(inputs=3, outputs=1)
    gene = NormalGeneConfig(
        activation_default=Act.tanh,
        activation_options=(Act.tanh,),
    )
    # Pendulum expects an action in [-2, 2]; tanh outputs lie in [-1, 1],
    # so the network output is scaled by 2.
    problem = GymNaxConfig(
        env_name='Pendulum-v1',
        output_transform=lambda out: out * 2,
    )
    return Config(basic=basic, neat=neat, gene=gene, problem=problem)
if __name__ == '__main__':
    # Wire the configuration, the NEAT algorithm and the gymnax problem together.
    conf = example_conf()
    algorithm = NEAT(conf, NormalGene)
    pipeline = Pipeline(conf, algorithm, GymNaxEnv)

    # Initialise the state, compile the step function up front, then evolve.
    state = pipeline.setup()
    pipeline.pre_compile(state)
    state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,36 @@
from config import *
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
    """Assemble the ``Config`` used by the Reacher-misc example."""
    basic = BasicConfig(seed=42, fitness_target=500, pop_size=10000)
    neat = NeatConfig(inputs=8, outputs=2)
    gene = NormalGeneConfig(
        activation_default=Act.sigmoid,
        activation_options=(Act.sigmoid,),
    )
    # No output_transform is given, so the problem config's default is used.
    problem = GymNaxConfig(env_name='Reacher-misc')
    return Config(basic=basic, neat=neat, gene=gene, problem=problem)
if __name__ == '__main__':
    # Wire the configuration, the NEAT algorithm and the gymnax problem together.
    conf = example_conf()
    algorithm = NEAT(conf, NormalGene)
    pipeline = Pipeline(conf, algorithm, GymNaxEnv)

    # Initialise the state, compile the step function up front, then evolve.
    state = pipeline.setup()
    pipeline.pre_compile(state)
    state, best = pipeline.auto_run(state)

120
pipeline_jitable_env.py Normal file
View File

@@ -0,0 +1,120 @@
"""
pipeline for jitable env like func_fit, gymnax
"""
from functools import partial
from typing import Type
import jax
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import State, Algorithm, Problem
class Pipeline:
    """Evolution loop for problems whose ``evaluate`` can be jit-compiled.

    The pipeline asks the algorithm for a population, evaluates every member
    on the problem in one vmapped call, reports the fitnesses back to the
    algorithm, and repeats until the fitness target or the generation limit
    is reached.
    """

    def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
        """Instantiate the problem and check it matches the algorithm's inputs.

        Args:
            config: global configuration; ``config.problem`` is forwarded to
                ``problem_type``.
            algorithm: a ``NEAT`` or ``HyperNEAT`` instance.
            problem_type: problem class; must have ``jitable`` set to a truthy value.

        Raises:
            AssertionError: if the problem is not jitable or its input size
                differs from the algorithm's configured input count.
            NotImplementedError: if ``algorithm`` is neither NEAT nor HyperNEAT.
        """
        assert problem_type.jitable, "problem must be jitable"

        self.config = config
        self.algorithm = algorithm
        self.problem = problem_type(config.problem)

        # The two supported algorithms keep their input count in different
        # config sections; the check itself is the same for both.
        if isinstance(algorithm, NEAT):
            expected_inputs = config.neat.inputs
        elif isinstance(algorithm, HyperNEAT):
            expected_inputs = config.hyperneat.inputs
        else:
            raise NotImplementedError(
                f"unsupported algorithm type: {type(algorithm).__name__}"
            )
        assert expected_inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"

        # Add one vmap over the observation argument for every leading
        # dimension of the problem's input shape.
        self.act_func = self.algorithm.act
        for _ in range(len(self.problem.input_shape) - 1):
            self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))

        self.best_genome = None
        self.best_fitness = float('-inf')
        self.generation_timestamp = None

    def setup(self):
        """Create the initial state, seeded from ``config.basic.seed``.

        Returns:
            The algorithm's initial state with an extra ``evaluate_key`` entry
            holding the random key reserved for evaluation.
        """
        key = jax.random.PRNGKey(self.config.basic.seed)
        algorithm_key, evaluate_key = jax.random.split(key, 2)
        state = self.algorithm.setup(algorithm_key, State())
        return state.update(
            evaluate_key=evaluate_key
        )

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state):
        """Run one generation: ask, evaluate the whole population, tell.

        Returns:
            ``(new_state, fitnesses)`` where ``fitnesses`` has one entry per
            population member.
        """
        key, sub_key = jax.random.split(state.evaluate_key)
        # One independent evaluation key per population member.
        keys = jax.random.split(key, self.config.basic.pop_size)

        pop = self.algorithm.ask(state)
        pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)

        evaluate_pop = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))
        fitnesses = evaluate_pop(keys, state, self.act_func, pop_transformed)

        state = self.algorithm.tell(state, fitnesses)
        return state.update(evaluate_key=sub_key), fitnesses

    def auto_run(self, ini_state):
        """Evolve until the fitness target or the generation limit is hit.

        Returns:
            ``(state, best_genome)``: the final state and the best genome
            recorded by ``analysis`` so far.
        """
        state = ini_state
        for _ in range(self.config.basic.generation_limit):
            self.generation_timestamp = time.time()

            # Keep the population that is about to be evaluated so the best
            # genome can be looked up by fitness index afterwards.
            previous_pop = self.algorithm.ask(state)
            state, fitnesses = self.step(state)
            fitnesses = jax.device_get(fitnesses)

            self.analysis(state, previous_pop, fitnesses)

            # np.max reduces in native code instead of iterating the array
            # element by element like the builtin max.
            if np.max(fitnesses) >= self.config.basic.fitness_target:
                print("Fitness limit reached!")
                return state, self.best_genome

        print("Generation limit reached!")
        return state, self.best_genome

    def analysis(self, state, pop, fitnesses):
        """Print per-generation statistics and remember the best genome.

        Args:
            state: state after the generation; ``generation`` and
                ``species_info.member_count`` are read from it.
            pop: the population that produced ``fitnesses``.
            fitnesses: host-side array with one fitness per population member.
        """
        max_f, min_f = np.max(fitnesses), np.min(fitnesses)
        mean_f, std_f = np.mean(fitnesses), np.std(fitnesses)

        cost_time = time.time() - self.generation_timestamp

        max_idx = np.argmax(fitnesses)
        if fitnesses[max_idx] > self.best_fitness:
            self.best_fitness = fitnesses[max_idx]
            self.best_genome = pop[max_idx]

        # Only species slots with at least one member are reported.
        member_count = jax.device_get(state.species_info.member_count)
        species_sizes = [int(i) for i in member_count if i > 0]

        print(f"Generation: {state.generation}",
              f"species: {len(species_sizes)}, {species_sizes}",
              f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")

    def show(self, state, genome):
        """Transform ``genome`` and hand it to the problem's ``show`` method."""
        transformed = self.algorithm.transform(state, genome)
        self.problem.show(state.evaluate_key, state, self.act_func, transformed)

    def pre_compile(self, state):
        """Lower and compile ``step`` for ``state`` ahead of time, printing the time taken."""
        tic = time.time()
        print("start compile")
        # `self` is passed explicitly, matching static_argnums=(0,) on `step`.
        self.step.lower(self, state).compile()
        print(f"compile finished, cost time: {time.time() - tic}s")

View File

@@ -1,4 +1,3 @@
from functools import partial
from typing import Type
import jax
@@ -44,7 +43,6 @@ class Pipeline:
evaluate_key=evaluate_key
)
@partial(jax.jit, static_argnums=(0,))
def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key)
@@ -110,6 +108,4 @@ class Pipeline:
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
# compiled_step = jax.jit(self.step, static_argnums=(0,)).lower(state).compile()
# self.__dict__['step'] = compiled_step
print(f"compile finished, cost time: {time.time() - tic}s")

View File

@@ -17,6 +17,8 @@ class FuncFitConfig(ProblemConfig):
class FuncFit(Problem):
jitable = True
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
self.config = config
super().__init__(config)

40
problem/rl_env/gym_env.py Normal file
View File

@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Callable
import gym
from core import State
from .rl_unjit import RLEnv, RLEnvConfig
@dataclass(frozen=True)
class GymConfig(RLEnvConfig):
    """Configuration for an environment identified by its ``gym`` id."""

    # Id of the environment to create.
    env_name: str = "CartPole-v1"

    def __post_init__(self):
        """Check that ``env_name`` is known to gym's environment registry.

        Raises:
            AssertionError: if the id is not registered.
        """
        # gym keeps its registered environments in `gym.envs.registry`; the
        # `registered_envs` attribute used before belongs to gymnax and does
        # not exist on the `gym` package.
        registry = gym.envs.registry
        # Older gym releases wrap the specs in a registry object exposing an
        # `env_specs` dict; newer ones use a plain dict keyed by env id.
        known_envs = getattr(registry, "env_specs", registry)
        assert self.env_name in known_envs, f"Env {self.env_name} not registered"
class GymNaxEnv(RLEnv):
    """RL problem backed by an environment created through ``gym.make``.

    NOTE(review): the calls below (unpacking ``gym.make`` into an env and its
    params, passing a random key, env state and env params to ``step`` /
    ``reset``, and calling ``observation_space`` / ``action_space`` with the
    params) follow the gymnax calling convention. Confirm they are valid for
    the ``gym`` package imported by this module. The class name also matches
    the gymnax-backed environment; confirm that is intended.
    """

    def __init__(self, config: GymConfig = GymConfig()):
        """Store the config and create the environment named by ``config.env_name``."""
        super().__init__(config)
        self.config = config
        made = gym.make(config.env_name)
        self.env, self.env_params = made

    def env_step(self, randkey, env_state, action):
        """Forward one step to the wrapped environment and return its result."""
        env = self.env
        return env.step(randkey, env_state, action, self.env_params)

    def env_reset(self, randkey):
        """Forward a reset to the wrapped environment and return its result."""
        env = self.env
        return env.reset(randkey, self.env_params)

    @property
    def input_shape(self):
        """Shape of the environment's observation space."""
        space = self.env.observation_space(self.env_params)
        return space.shape

    @property
    def output_shape(self):
        """Shape of the environment's action space."""
        space = self.env.action_space(self.env_params)
        return space.shape

    def show(self, randkey, state: State, act_func: Callable, params):
        """Rendering is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -1,12 +1,10 @@
from dataclasses import dataclass
from typing import Callable
import jax
import jax.numpy as jnp
import gymnax
from core import State
from .rl_env import RLEnv, RLEnvConfig
from .rl_jit import RLEnv, RLEnvConfig
@dataclass(frozen=True)

View File

@@ -16,6 +16,8 @@ class RLEnvConfig(ProblemConfig):
class RLEnv(Problem):
jitable = True
def __init__(self, config: RLEnvConfig = RLEnvConfig()):
super().__init__(config)
self.config = config

View File

@@ -0,0 +1,69 @@
from dataclasses import dataclass
from typing import Callable
import jax
from config import ProblemConfig
from core import Problem, State
@dataclass(frozen=True)
class RLEnvConfig(ProblemConfig):
    """Configuration shared by reinforcement-learning problems."""

    # Applied to the raw network output to obtain the action handed to the
    # environment; the default passes the output through unchanged.
    output_transform: Callable = lambda net_out: net_out
class RLEnv(Problem):
    """Base class for episodic RL problems.

    Subclasses supply ``env_step``, ``env_reset``, ``input_shape``,
    ``output_shape`` and ``show``; ``evaluate`` runs one episode and returns
    the accumulated reward.

    NOTE(review): ``jitable`` is False, yet ``evaluate`` is built on
    ``jax.lax.while_loop``, which traces ``step``. Confirm this works for the
    environments this class is meant to wrap.
    """

    jitable = False

    def __init__(self, config: RLEnvConfig = RLEnvConfig()):
        super().__init__(config)
        self.config = config

    def evaluate(self, randkey, state: State, act_func: Callable, params):
        """Run a single episode and return its total reward.

        Args:
            randkey: random key; split into a reset key and an episode key.
            state: algorithm state forwarded to ``act_func``.
            act_func: called as ``act_func(state, obs, params)`` to get the
                network output for an observation.
            params: network parameters forwarded to ``act_func``.

        Returns:
            Sum of the rewards collected until the environment reports done.
        """
        rng_reset, rng_episode = jax.random.split(randkey)
        init_obs, init_env_state = self.reset(rng_reset)

        def cond_func(carry):
            # Keep looping while the episode is not finished.
            _, _, _, done, _ = carry
            return ~done

        def body_func(carry):
            obs, env_state, rng, _, total_reward = carry

            # Split before use so the key given to the environment is never
            # the same key that seeds the next iteration (no key reuse).
            next_rng, step_rng = jax.random.split(rng)

            net_out = act_func(state, obs, params)
            action = self.config.output_transform(net_out)
            next_obs, next_env_state, reward, done, _ = self.step(step_rng, env_state, action)

            return next_obs, next_env_state, next_rng, done, total_reward + reward

        _, _, _, _, total_reward = jax.lax.while_loop(
            cond_func,
            body_func,
            (init_obs, init_env_state, rng_episode, False, 0.0)
        )

        return total_reward

    def step(self, randkey, env_state, action):
        """Advance the environment by one step; delegates to ``env_step``."""
        return self.env_step(randkey, env_state, action)

    def reset(self, randkey):
        """Reset the environment; delegates to ``env_reset``."""
        return self.env_reset(randkey)

    def env_step(self, randkey, env_state, action):
        """Environment-specific step; must be provided by subclasses."""
        raise NotImplementedError

    def env_reset(self, randkey):
        """Environment-specific reset; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def input_shape(self):
        """Observation shape; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def output_shape(self):
        """Action shape; must be provided by subclasses."""
        raise NotImplementedError

    def show(self, randkey, state: State, act_func: Callable, params):
        """Visualise a rollout; must be provided by subclasses."""
        raise NotImplementedError