add obs normalization for rl env

This commit is contained in:
wls2002
2024-06-14 16:11:50 +08:00
parent aac9f4c3fb
commit b9d6482d11
3 changed files with 135 additions and 19 deletions

View File

@@ -270,7 +270,7 @@ class DefaultGenome(BaseGenome):
fixed_args_output_funcs.append(f)
forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs]
forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
return (
symbols,

View File

@@ -1,9 +1,16 @@
import jax
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
def sample_policy(randkey, obs):
return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
@@ -17,7 +24,7 @@ if __name__ == "__main__":
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh
output_transform=Act.tanh,
),
pop_size=1000,
species_size=10,
@@ -25,6 +32,10 @@ if __name__ == "__main__":
),
problem=BraxEnv(
env_name="halfcheetah",
max_step=1000,
obs_normalization=True,
sample_episodes=1000,
sample_policy=sample_policy,
),
generation_limit=10000,
fitness_target=5000,

View File

@@ -1,8 +1,8 @@
from functools import partial
from typing import Callable
import jax
import jax.numpy as jnp
import numpy as np
from utils import State
from .. import BaseProblem
@@ -17,19 +17,90 @@ class RLEnv(BaseProblem):
repeat_times=1,
record_episode=False,
action_policy: Callable = None,
obs_normalization: bool = False,
sample_policy: Callable = None,
sample_episodes: int = 0,
):
"""
action_policy take three args:
randkey, forward_func, obs
randkey is a random key for jax.random
forward_func is a function which receive obs and return action forward_func(obs) - > action
obs is the observation of the environment
sample_policy take two args:
randkey, obs -> action
"""
super().__init__()
self.max_step = max_step
self.record_episode = record_episode
self.repeat_times = repeat_times
self.action_policy = action_policy
if obs_normalization:
assert sample_policy is not None, "sample_policy must be provided"
assert sample_episodes > 0, "sample_size must be greater than 0"
self.sample_policy = sample_policy
self.sample_episodes = sample_episodes
self.obs_normalization = obs_normalization
def setup(self, state=State()):
if self.obs_normalization:
print("Sampling episodes for normalization")
keys = jax.random.split(state.randkey, self.sample_episodes)
dummy_act_func = (
lambda s, p, o: o
) # receive state, params, obs and return the original obs
dummy_sample_func = lambda rk, act_func, obs: self.sample_policy(
rk, obs
) # ignore act_func
def sample(rk):
return self.evaluate_once(
state, rk, dummy_act_func, None, dummy_sample_func, True
)
rewards, episodes = jax.jit(jax.vmap(sample))(keys)
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
obs = obs.reshape(
-1, *self.input_shape
) # shape: (sample_episodes * max_step, *input_shape)
obs_axis = tuple(range(obs.ndim))
valid_data_flag = np.all(~jnp.isnan(obs), axis=obs_axis[1:])
obs = obs[valid_data_flag]
obs_mean = np.mean(obs, axis=0)
obs_std = np.std(obs, axis=0)
state = state.register(
problem_obs_mean=obs_mean,
problem_obs_std=obs_std,
)
print("Sampling episodes for normalization finished.")
print("valid data count: ", obs.shape[0])
print("obs_mean: ", obs_mean)
print("obs_std: ", obs_std)
return state
def evaluate(self, state: State, randkey, act_func: Callable, params):
keys = jax.random.split(randkey, self.repeat_times)
if self.record_episode:
rewards, episodes = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None)
)(state, keys, act_func, params)
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
True,
self.obs_normalization,
)
episodes["obs"] = episodes["obs"].reshape(
self.max_step * self.repeat_times, *self.input_shape
)
@@ -43,16 +114,34 @@ class RLEnv(BaseProblem):
return rewards.mean(), episodes
else:
rewards = jax.vmap(self.evaluate_once, in_axes=(None, 0, None, None))(
state, keys, act_func, params
rewards = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
False,
self.obs_normalization,
)
return rewards.mean()
def evaluate_once(self, state, randkey, act_func, params):
def evaluate_once(
self,
state,
randkey,
act_func,
params,
action_policy,
record_episode,
normalize_obs=False,
):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
if self.record_episode:
if record_episode:
obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan)
action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan)
reward_array = jnp.full((self.max_step,), jnp.nan)
@@ -65,14 +154,27 @@ class RLEnv(BaseProblem):
episode = None
def cond_func(carry):
_, _, _, done, _, count, _ = carry
_, _, _, done, _, count, _, rk = carry
return ~done & (count < self.max_step)
def body_func(carry):
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
if self.action_policy is not None:
(
obs,
env_state,
rng,
done,
tr,
count,
epis,
rk,
) = carry # tr -> total reward; rk -> randkey
if normalize_obs:
obs = norm_obs(state, obs)
if action_policy is not None:
forward_func = lambda obs: act_func(state, params, obs)
action = self.action_policy(forward_func, obs)
action = action_policy(rk, forward_func, obs)
else:
action = act_func(state, params, obs)
next_obs, next_env_state, reward, done, _ = self.step(
@@ -80,7 +182,7 @@ class RLEnv(BaseProblem):
)
next_rng, _ = jax.random.split(rng)
if self.record_episode:
if record_episode:
epis["obs"] = epis["obs"].at[count].set(obs)
epis["action"] = epis["action"].at[count].set(action)
epis["reward"] = epis["reward"].at[count].set(reward)
@@ -93,24 +195,23 @@ class RLEnv(BaseProblem):
tr + reward,
count + 1,
epis,
jax.random.split(rk)[0],
)
_, _, _, _, total_reward, _, episode = jax.lax.while_loop(
_, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
cond_func,
body_func,
(init_obs, init_env_state, rng_episode, False, 0.0, 0, episode),
(init_obs, init_env_state, rng_episode, False, 0.0, 0, episode, randkey),
)
if self.record_episode:
if record_episode:
return total_reward, episode
else:
return total_reward
# @partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action):
return self.env_step(randkey, env_state, action)
# @partial(jax.jit, static_argnums=(0,))
def reset(self, randkey):
return self.env_reset(randkey)
@@ -130,3 +231,7 @@ class RLEnv(BaseProblem):
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError
def norm_obs(state, obs):
return (obs - state.problem_obs_mean) / (state.problem_obs_std + 1e-6)