add more rl task in examples
This commit is contained in:
@@ -6,7 +6,6 @@ import numpy as np
|
||||
|
||||
from config import Config, HyperNeatConfig
|
||||
from core import Algorithm, Substrate, State, Genome, Gene
|
||||
from utils import Act, Agg
|
||||
from .substrate import analysis_substrate
|
||||
from algorithm import NEAT
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from .state import State
|
||||
|
||||
class Problem:
|
||||
|
||||
jitable: bool
|
||||
|
||||
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
|
||||
self.config = problem_config
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from config import *
|
||||
from pipeline import Pipeline
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.func_fit import XOR, FuncFitConfig
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from config import *
|
||||
from pipeline import Pipeline
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm.neat import NormalGene, NormalGeneConfig
|
||||
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
||||
from problem.func_fit import XOR3d, FuncFitConfig
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from config import *
|
||||
from pipeline import Pipeline
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
||||
from problem.func_fit import XOR3d, FuncFitConfig
|
||||
|
||||
39
examples/gymnax/acrobot.py
Normal file
39
examples/gymnax/acrobot.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def example_conf():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=0,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=6,
|
||||
outputs=3,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='Acrobot-v1',
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = example_conf()
|
||||
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline import Pipeline
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
39
examples/gymnax/mountain_car.py
Normal file
39
examples/gymnax/mountain_car.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def example_conf():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=0,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=2,
|
||||
outputs=3,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.sigmoid,
|
||||
activation_options=(Act.sigmoid,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='MountainCar-v0',
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1, 2}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = example_conf()
|
||||
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
38
examples/gymnax/mountain_car_continuous.py
Normal file
38
examples/gymnax/mountain_car_continuous.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def example_conf():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=100,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=2,
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='MountainCarContinuous-v0'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = example_conf()
|
||||
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
39
examples/gymnax/pendulum.py
Normal file
39
examples/gymnax/pendulum.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def example_conf():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=0,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=3,
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='Pendulum-v1',
|
||||
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = example_conf()
|
||||
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
36
examples/gymnax/reacher.py
Normal file
36
examples/gymnax/reacher.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def example_conf():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=500,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=8,
|
||||
outputs=2,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.sigmoid,
|
||||
activation_options=(Act.sigmoid,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='Reacher-misc',
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = example_conf()
|
||||
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
120
pipeline_jitable_env.py
Normal file
120
pipeline_jitable_env.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
pipeline for jitable env like func_fit, gymnax
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from algorithm import NEAT, HyperNEAT
|
||||
from config import Config
|
||||
from core import State, Algorithm, Problem
|
||||
|
||||
|
||||
class Pipeline:
|
||||
|
||||
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
|
||||
|
||||
assert problem_type.jitable, "problem must be jitable"
|
||||
|
||||
self.config = config
|
||||
self.algorithm = algorithm
|
||||
self.problem = problem_type(config.problem)
|
||||
|
||||
if isinstance(algorithm, NEAT):
|
||||
assert config.neat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
|
||||
|
||||
elif isinstance(algorithm, HyperNEAT):
|
||||
assert config.hyperneat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.act_func = self.algorithm.act
|
||||
|
||||
for _ in range(len(self.problem.input_shape) - 1):
|
||||
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
|
||||
|
||||
self.best_genome = None
|
||||
self.best_fitness = float('-inf')
|
||||
self.generation_timestamp = None
|
||||
|
||||
def setup(self):
|
||||
key = jax.random.PRNGKey(self.config.basic.seed)
|
||||
algorithm_key, evaluate_key = jax.random.split(key, 2)
|
||||
state = State()
|
||||
state = self.algorithm.setup(algorithm_key, state)
|
||||
return state.update(
|
||||
evaluate_key=evaluate_key
|
||||
)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, state):
|
||||
|
||||
key, sub_key = jax.random.split(state.evaluate_key)
|
||||
keys = jax.random.split(key, self.config.basic.pop_size)
|
||||
|
||||
pop = self.algorithm.ask(state)
|
||||
|
||||
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
|
||||
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
|
||||
pop_transformed)
|
||||
|
||||
state = self.algorithm.tell(state, fitnesses)
|
||||
|
||||
return state.update(evaluate_key=sub_key), fitnesses
|
||||
|
||||
def auto_run(self, ini_state):
|
||||
state = ini_state
|
||||
for _ in range(self.config.basic.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
previous_pop = self.algorithm.ask(state)
|
||||
|
||||
state, fitnesses = self.step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
|
||||
self.analysis(state, previous_pop, fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config.basic.fitness_target:
|
||||
print("Fitness limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
print("Generation limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
|
||||
new_timestamp = time.time()
|
||||
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[max_idx]
|
||||
|
||||
member_count = jax.device_get(state.species_info.member_count)
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
print(f"Generation: {state.generation}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
||||
|
||||
def show(self, state, genome):
|
||||
transformed = self.algorithm.transform(state, genome)
|
||||
self.problem.show(state.evaluate_key, state, self.act_func, transformed)
|
||||
|
||||
def pre_compile(self, state):
|
||||
tic = time.time()
|
||||
print("start compile")
|
||||
self.step.lower(self, state).compile()
|
||||
print(f"compile finished, cost time: {time.time() - tic}s")
|
||||
@@ -1,4 +1,3 @@
|
||||
from functools import partial
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
@@ -44,7 +43,6 @@ class Pipeline:
|
||||
evaluate_key=evaluate_key
|
||||
)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, state):
|
||||
|
||||
key, sub_key = jax.random.split(state.evaluate_key)
|
||||
@@ -110,6 +108,4 @@ class Pipeline:
|
||||
tic = time.time()
|
||||
print("start compile")
|
||||
self.step.lower(self, state).compile()
|
||||
# compiled_step = jax.jit(self.step, static_argnums=(0,)).lower(state).compile()
|
||||
# self.__dict__['step'] = compiled_step
|
||||
print(f"compile finished, cost time: {time.time() - tic}s")
|
||||
@@ -17,6 +17,8 @@ class FuncFitConfig(ProblemConfig):
|
||||
|
||||
class FuncFit(Problem):
|
||||
|
||||
jitable = True
|
||||
|
||||
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
|
||||
40
problem/rl_env/gym_env.py
Normal file
40
problem/rl_env/gym_env.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import gym
|
||||
|
||||
from core import State
|
||||
from .rl_unjit import RLEnv, RLEnvConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GymConfig(RLEnvConfig):
|
||||
env_name: str = "CartPole-v1"
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.env_name in gym.registered_envs, f"Env {self.env_name} not registered"
|
||||
|
||||
|
||||
class GymNaxEnv(RLEnv):
|
||||
|
||||
def __init__(self, config: GymConfig = GymConfig()):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.env, self.env_params = gym.make(config.env_name)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
return self.env.step(randkey, env_state, action, self.env_params)
|
||||
|
||||
def env_reset(self, randkey):
|
||||
return self.env.reset(randkey, self.env_params)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return self.env.observation_space(self.env_params).shape
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self.env.action_space(self.env_params).shape
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params):
|
||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
||||
@@ -1,12 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import gymnax
|
||||
|
||||
from core import State
|
||||
from .rl_env import RLEnv, RLEnvConfig
|
||||
from .rl_jit import RLEnv, RLEnvConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -16,6 +16,8 @@ class RLEnvConfig(ProblemConfig):
|
||||
|
||||
class RLEnv(Problem):
|
||||
|
||||
jitable = True
|
||||
|
||||
def __init__(self, config: RLEnvConfig = RLEnvConfig()):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
69
problem/rl_env/rl_unjit.py
Normal file
69
problem/rl_env/rl_unjit.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
|
||||
from config import ProblemConfig
|
||||
|
||||
from core import Problem, State
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RLEnvConfig(ProblemConfig):
|
||||
output_transform: Callable = lambda x: x
|
||||
|
||||
|
||||
class RLEnv(Problem):
|
||||
|
||||
jitable = False
|
||||
|
||||
def __init__(self, config: RLEnvConfig = RLEnvConfig()):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
def cond_func(carry):
|
||||
_, _, _, done, _ = carry
|
||||
return ~done
|
||||
|
||||
def body_func(carry):
|
||||
obs, env_state, rng, _, tr = carry # total reward
|
||||
net_out = act_func(state, obs, params)
|
||||
action = self.config.output_transform(net_out)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
|
||||
_, _, _, _, total_reward = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
)
|
||||
|
||||
return total_reward
|
||||
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
def reset(self, randkey):
|
||||
return self.env_reset(randkey)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
raise NotImplementedError
|
||||
|
||||
def env_reset(self, randkey):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params):
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user