add obs normalization for rl env

commit b9d6482d11
parent aac9f4c3fb
Author: wls2002
Date:   2024-06-14 16:11:50 +08:00

3 changed files with 135 additions and 19 deletions
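The change teaches RLEnv (and thus BraxEnv) to optionally standardize observations: before evolution starts, it rolls out randomly sampled episodes to estimate per-dimension observation statistics, and every observation is then shifted and scaled before being fed to the network. A minimal sketch of the transformation the rest of the diff implements (obs_mean/obs_std are the sampled statistics registered in the pipeline state):

    import jax.numpy as jnp

    def normalize(obs, obs_mean, obs_std):
        # standardize per dimension; the epsilon keeps the division
        # finite for dimensions whose sampled std is zero
        return (obs - obs_mean) / (obs_std + 1e-6)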


@@ -270,7 +270,7 @@ class DefaultGenome(BaseGenome):
             fixed_args_output_funcs.append(f)
-        forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs]
+        forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
         return (
             symbols,
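Wrapping the outputs in jnp.array turns a Python list of per-output scalars into a single JAX array, which is what array-consuming code downstream expects (e.g. the env step and the .at[count].set(action) episode recording in RLEnv below). A toy sketch with hypothetical output functions:

    import jax.numpy as jnp

    fs = [lambda x: x.sum(), lambda x: x.mean()]   # hypothetical output funcs
    inputs = jnp.arange(4.0)

    as_list = [f(inputs) for f in fs]              # Python list of 0-d arrays
    as_array = jnp.array([f(inputs) for f in fs])  # single array of shape (2,)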


@@ -1,9 +1,16 @@
+import jax
 from pipeline import Pipeline
 from algorithm.neat import *
 from problem.rl_env import BraxEnv
 from utils import Act

+
+def sample_policy(randkey, obs):
+    return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
+
+
 if __name__ == "__main__":
     pipeline = Pipeline(
         algorithm=NEAT(
@@ -17,7 +24,7 @@ if __name__ == "__main__":
                     activation_options=(Act.tanh,),
                     activation_default=Act.tanh,
                 ),
-                output_transform=Act.tanh
+                output_transform=Act.tanh,
             ),
             pop_size=1000,
             species_size=10,
@@ -25,6 +32,10 @@ if __name__ == "__main__":
         ),
         problem=BraxEnv(
             env_name="halfcheetah",
+            max_step=1000,
+            obs_normalization=True,
+            sample_episodes=1000,
+            sample_policy=sample_policy,
         ),
         generation_limit=10000,
         fitness_target=5000,
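With obs_normalization=True, BraxEnv first runs sample_episodes rollouts driven by sample_policy to collect observation statistics; here that is 1000 episodes of uniform random actions in [-1, 1], matching halfcheetah's 6-dimensional action space. Any function with the (randkey, obs) -> action signature works; for example, a hypothetical Gaussian sampler:

    import jax
    import jax.numpy as jnp

    def gaussian_sample_policy(randkey, obs):
        # hypothetical alternative: zero-mean Gaussian actions,
        # clipped to the valid action range
        return jnp.clip(0.5 * jax.random.normal(randkey, (6,)), -1, 1)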


@@ -1,8 +1,8 @@
-from functools import partial
 from typing import Callable

 import jax
 import jax.numpy as jnp
+import numpy as np

 from utils import State
 from .. import BaseProblem
@@ -17,19 +17,90 @@ class RLEnv(BaseProblem):
         repeat_times=1,
         record_episode=False,
         action_policy: Callable = None,
+        obs_normalization: bool = False,
+        sample_policy: Callable = None,
+        sample_episodes: int = 0,
     ):
+        """
+        action_policy takes three args:
+            randkey, forward_func, obs -> action
+        where randkey is a random key for jax.random, forward_func is a
+        function that maps an observation to an action
+        (forward_func(obs) -> action), and obs is the observation of the
+        environment.
+
+        sample_policy takes two args:
+            randkey, obs -> action
+        """
         super().__init__()
         self.max_step = max_step
         self.record_episode = record_episode
         self.repeat_times = repeat_times
         self.action_policy = action_policy
+        if obs_normalization:
+            assert sample_policy is not None, "sample_policy must be provided"
+            assert sample_episodes > 0, "sample_episodes must be greater than 0"
+        self.sample_policy = sample_policy
+        self.sample_episodes = sample_episodes
+        self.obs_normalization = obs_normalization
+
+    def setup(self, state=State()):
+        if self.obs_normalization:
+            print("Sampling episodes for normalization")
+            keys = jax.random.split(state.randkey, self.sample_episodes)
+
+            dummy_act_func = (
+                lambda s, p, o: o
+            )  # receives state, params, obs and returns the original obs
+            dummy_sample_func = lambda rk, act_func, obs: self.sample_policy(
+                rk, obs
+            )  # ignores act_func
+
+            def sample(rk):
+                return self.evaluate_once(
+                    state, rk, dummy_act_func, None, dummy_sample_func, True
+                )
+
+            rewards, episodes = jax.jit(jax.vmap(sample))(keys)
+            obs = jax.device_get(
+                episodes["obs"]
+            )  # shape: (sample_episodes, max_step, *input_shape)
+            obs = obs.reshape(
+                -1, *self.input_shape
+            )  # shape: (sample_episodes * max_step, *input_shape)
+
+            obs_axis = tuple(range(obs.ndim))
+            valid_data_flag = np.all(~np.isnan(obs), axis=obs_axis[1:])
+            obs = obs[valid_data_flag]  # drop NaN-padded steps from short episodes
+
+            obs_mean = np.mean(obs, axis=0)
+            obs_std = np.std(obs, axis=0)
+
+            state = state.register(
+                problem_obs_mean=obs_mean,
+                problem_obs_std=obs_std,
+            )
+            print("Sampling episodes for normalization finished.")
+            print("valid data count: ", obs.shape[0])
+            print("obs_mean: ", obs_mean)
+            print("obs_std: ", obs_std)
+
+        return state
+
     def evaluate(self, state: State, randkey, act_func: Callable, params):
         keys = jax.random.split(randkey, self.repeat_times)

         if self.record_episode:
             rewards, episodes = jax.vmap(
-                self.evaluate_once, in_axes=(None, 0, None, None)
-            )(state, keys, act_func, params)
+                self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
+            )(
+                state,
+                keys,
+                act_func,
+                params,
+                self.action_policy,
+                True,
+                self.obs_normalization,
+            )
             episodes["obs"] = episodes["obs"].reshape(
                 self.max_step * self.repeat_times, *self.input_shape
             )
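Episodes that terminate before max_step leave their remaining rows as the NaN padding created in evaluate_once, so setup drops every row containing a NaN before computing the statistics. A toy illustration of that filtering with made-up numbers:

    import numpy as np

    obs = np.array([[1.0, 2.0],
                    [np.nan, 3.0],   # padded step from a short episode
                    [5.0, 6.0]])

    valid = np.all(~np.isnan(obs), axis=1)         # [True, False, True]
    obs = obs[valid]
    mean, std = obs.mean(axis=0), obs.std(axis=0)  # [3., 4.], [2., 2.]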
@@ -43,16 +114,34 @@ class RLEnv(BaseProblem):
             return rewards.mean(), episodes
         else:
-            rewards = jax.vmap(self.evaluate_once, in_axes=(None, 0, None, None))(
-                state, keys, act_func, params
+            rewards = jax.vmap(
+                self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
+            )(
+                state,
+                keys,
+                act_func,
+                params,
+                self.action_policy,
+                False,
+                self.obs_normalization,
             )
             return rewards.mean()

-    def evaluate_once(self, state, randkey, act_func, params):
+    def evaluate_once(
+        self,
+        state,
+        randkey,
+        act_func,
+        params,
+        action_policy,
+        record_episode,
+        normalize_obs=False,
+    ):
         rng_reset, rng_episode = jax.random.split(randkey)
         init_obs, init_env_state = self.reset(rng_reset)

-        if self.record_episode:
+        if record_episode:
             obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan)
             action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan)
             reward_array = jnp.full((self.max_step,), jnp.nan)
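Making action_policy, record_episode, and normalize_obs explicit parameters lets setup reuse evaluate_once with a dummy policy, and it plays well with jax.vmap: arguments mapped with in_axes=None are passed through unbatched, so the bool and the callable stay ordinary Python values and branches like if record_episode: are resolved at trace time. A minimal sketch of that behavior:

    import jax
    import jax.numpy as jnp

    def f(x, flag):
        # `flag` arrives as a plain Python bool under in_axes=None,
        # so this branch is decided when the function is traced
        return x * 2 if flag else x

    out = jax.vmap(f, in_axes=(0, None))(jnp.arange(3.0), True)  # [0., 2., 4.]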
@@ -65,14 +154,27 @@ class RLEnv(BaseProblem):
             episode = None

         def cond_func(carry):
-            _, _, _, done, _, count, _ = carry
+            _, _, _, done, _, count, _, rk = carry
             return ~done & (count < self.max_step)

         def body_func(carry):
-            obs, env_state, rng, done, tr, count, epis = carry  # tr -> total reward
-            if self.action_policy is not None:
+            (
+                obs,
+                env_state,
+                rng,
+                done,
+                tr,
+                count,
+                epis,
+                rk,
+            ) = carry  # tr -> total reward; rk -> randkey
+
+            if normalize_obs:
+                obs = norm_obs(state, obs)
+
+            if action_policy is not None:
                 forward_func = lambda obs: act_func(state, params, obs)
-                action = self.action_policy(forward_func, obs)
+                action = action_policy(rk, forward_func, obs)
             else:
                 action = act_func(state, params, obs)

             next_obs, next_env_state, reward, done, _ = self.step(
@@ -80,7 +182,7 @@ class RLEnv(BaseProblem):
             )
             next_rng, _ = jax.random.split(rng)

-            if self.record_episode:
+            if record_episode:
                 epis["obs"] = epis["obs"].at[count].set(obs)
                 epis["action"] = epis["action"].at[count].set(action)
                 epis["reward"] = epis["reward"].at[count].set(reward)
@@ -93,24 +195,23 @@ class RLEnv(BaseProblem):
                 tr + reward,
                 count + 1,
                 epis,
+                jax.random.split(rk)[0],
             )

-        _, _, _, _, total_reward, _, episode = jax.lax.while_loop(
+        _, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
             cond_func,
             body_func,
-            (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode),
+            (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode, randkey),
         )

-        if self.record_episode:
+        if record_episode:
             return total_reward, episode
         else:
             return total_reward

-    # @partial(jax.jit, static_argnums=(0,))
     def step(self, randkey, env_state, action):
         return self.env_step(randkey, env_state, action)

-    # @partial(jax.jit, static_argnums=(0,))
     def reset(self, randkey):
         return self.env_reset(randkey)
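The loop carry now threads a random key (rk) that is split at every step, so a stochastic action_policy gets fresh randomness on each timestep instead of reusing one key; this is the usual pattern for randomness inside jax.lax.while_loop. A condensed sketch:

    import jax

    def body(carry):
        rk, count = carry
        step_key, rk = jax.random.split(rk)  # fresh key for this step's policy call
        return rk, count + 1

    carry = jax.lax.while_loop(lambda c: c[1] < 10, body, (jax.random.PRNGKey(0), 0))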
@@ -130,3 +231,7 @@ class RLEnv(BaseProblem):
     def show(self, state, randkey, act_func, params, *args, **kwargs):
         raise NotImplementedError
+
+
+def norm_obs(state, obs):
+    return (obs - state.problem_obs_mean) / (state.problem_obs_std + 1e-6)
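norm_obs is module-level rather than a method, so anything holding a state with the registered problem_obs_mean / problem_obs_std can call it. The 1e-6 epsilon keeps the division finite for observation dimensions whose sampled std is zero; a quick check with made-up statistics:

    import numpy as np

    obs_mean, obs_std = np.array([3.0, 4.0]), np.array([2.0, 0.0])
    obs = np.array([5.0, 4.0])
    normed = (obs - obs_mean) / (obs_std + 1e-6)  # ~[1., 0.], finite despite std == 0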