From b9d6482d11f24a045dca3c1606d07085a3ade6c7 Mon Sep 17 00:00:00 2001
From: wls2002
Date: Fri, 14 Jun 2024 16:11:50 +0800
Subject: [PATCH] add obs normalization for rl env

---
 tensorneat/algorithm/neat/genome/default.py |   2 +-
 tensorneat/examples/brax/half_cheetah.py    |  13 +-
 tensorneat/problem/rl_env/rl_jit.py         | 139 +++++++++++++++++---
 3 files changed, 135 insertions(+), 19 deletions(-)

diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py
index 2c10aa2..d742b62 100644
--- a/tensorneat/algorithm/neat/genome/default.py
+++ b/tensorneat/algorithm/neat/genome/default.py
@@ -270,7 +270,7 @@ class DefaultGenome(BaseGenome):
 
             fixed_args_output_funcs.append(f)
 
-        forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs]
+        forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
 
         return (
             symbols,
diff --git a/tensorneat/examples/brax/half_cheetah.py b/tensorneat/examples/brax/half_cheetah.py
index bcc515a..c23a941 100644
--- a/tensorneat/examples/brax/half_cheetah.py
+++ b/tensorneat/examples/brax/half_cheetah.py
@@ -1,9 +1,16 @@
+import jax
+
 from pipeline import Pipeline
 from algorithm.neat import *
 
 from problem.rl_env import BraxEnv
 from utils import Act
 
+
+def sample_policy(randkey, obs):
+    return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
+
+
 if __name__ == "__main__":
     pipeline = Pipeline(
         algorithm=NEAT(
@@ -17,7 +24,7 @@ if __name__ == "__main__":
                     activation_options=(Act.tanh,),
                     activation_default=Act.tanh,
                 ),
-                output_transform=Act.tanh
+                output_transform=Act.tanh,
             ),
             pop_size=1000,
             species_size=10,
@@ -25,6 +32,10 @@
         ),
         problem=BraxEnv(
             env_name="halfcheetah",
+            max_step=1000,
+            obs_normalization=True,
+            sample_episodes=1000,
+            sample_policy=sample_policy,
         ),
         generation_limit=10000,
         fitness_target=5000,
diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py
index c5d54c8..1dbe8b1 100644
--- a/tensorneat/problem/rl_env/rl_jit.py
+++ b/tensorneat/problem/rl_env/rl_jit.py
@@ -1,8 +1,8 @@
-from functools import partial
 from typing import Callable
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 
 from utils import State
 from .. import BaseProblem
@@ -17,19 +17,90 @@ class RLEnv(BaseProblem):
         repeat_times=1,
         record_episode=False,
         action_policy: Callable = None,
+        obs_normalization: bool = False,
+        sample_policy: Callable = None,
+        sample_episodes: int = 0,
     ):
+        """
+        action_policy takes three args:
+            randkey, forward_func, obs
+        randkey is a random key for jax.random;
+        forward_func is a function which receives obs and returns an action: forward_func(obs) -> action;
+        obs is the observation of the environment.
+
+        sample_policy takes two args:
+            randkey, obs -> action
+        """
+
         super().__init__()
         self.max_step = max_step
         self.record_episode = record_episode
         self.repeat_times = repeat_times
         self.action_policy = action_policy
+        if obs_normalization:
+            assert sample_policy is not None, "sample_policy must be provided"
+            assert sample_episodes > 0, "sample_episodes must be greater than 0"
+        self.sample_policy = sample_policy
+        self.sample_episodes = sample_episodes
+        self.obs_normalization = obs_normalization
+
+    def setup(self, state=State()):
+        if self.obs_normalization:
+            print("Sampling episodes for normalization")
+            keys = jax.random.split(state.randkey, self.sample_episodes)
+            dummy_act_func = (
+                lambda s, p, o: o
+            )  # receives state, params, obs and returns the original obs
+            dummy_sample_func = lambda rk, act_func, obs: self.sample_policy(
+                rk, obs
+            )  # ignores act_func
+
+            def sample(rk):
+                return self.evaluate_once(
+                    state, rk, dummy_act_func, None, dummy_sample_func, True
+                )
+
+            rewards, episodes = jax.jit(jax.vmap(sample))(keys)
+
+            obs = jax.device_get(episodes["obs"])  # shape: (sample_episodes, max_step, *input_shape)
+            obs = obs.reshape(
+                -1, *self.input_shape
+            )  # shape: (sample_episodes * max_step, *input_shape)
+
+            obs_axis = tuple(range(obs.ndim))
+            valid_data_flag = np.all(~np.isnan(obs), axis=obs_axis[1:])
+            obs = obs[valid_data_flag]
+
+            obs_mean = np.mean(obs, axis=0)
+            obs_std = np.std(obs, axis=0)
+
+            state = state.register(
+                problem_obs_mean=obs_mean,
+                problem_obs_std=obs_std,
+            )
+
+            print("Sampling episodes for normalization finished.")
+            print("valid data count: ", obs.shape[0])
+            print("obs_mean: ", obs_mean)
+            print("obs_std: ", obs_std)
+        return state
+
     def evaluate(self, state: State, randkey, act_func: Callable, params):
         keys = jax.random.split(randkey, self.repeat_times)
 
         if self.record_episode:
             rewards, episodes = jax.vmap(
-                self.evaluate_once, in_axes=(None, 0, None, None)
-            )(state, keys, act_func, params)
+                self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
+            )(
+                state,
+                keys,
+                act_func,
+                params,
+                self.action_policy,
+                True,
+                self.obs_normalization,
+            )
+
             episodes["obs"] = episodes["obs"].reshape(
                 self.max_step * self.repeat_times, *self.input_shape
             )
@@ -43,16 +114,34 @@ class RLEnv(BaseProblem):
             return rewards.mean(), episodes
 
         else:
-            rewards = jax.vmap(self.evaluate_once, in_axes=(None, 0, None, None))(
-                state, keys, act_func, params
+            rewards = jax.vmap(
+                self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
+            )(
+                state,
+                keys,
+                act_func,
+                params,
+                self.action_policy,
+                False,
+                self.obs_normalization,
             )
+
             return rewards.mean()
 
-    def evaluate_once(self, state, randkey, act_func, params):
+    def evaluate_once(
+        self,
+        state,
+        randkey,
+        act_func,
+        params,
+        action_policy,
+        record_episode,
+        normalize_obs=False,
+    ):
         rng_reset, rng_episode = jax.random.split(randkey)
         init_obs, init_env_state = self.reset(rng_reset)
 
-        if self.record_episode:
+        if record_episode:
             obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan)
             action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan)
             reward_array = jnp.full((self.max_step,), jnp.nan)
@@ -65,14 +154,27 @@ class RLEnv(BaseProblem):
             episode = None
 
         def cond_func(carry):
-            _, _, _, done, _, count, _ = carry
+            _, _, _, done, _, count, _, rk = carry
             return ~done & (count < self.max_step)
 
         def body_func(carry):
-            obs, env_state, rng, done, tr, count, epis = carry  # tr -> total reward
-            if self.action_policy is not None:
+            (
+                obs,
+                env_state,
+                rng,
+                done,
+                tr,
+                count,
+                epis,
+                rk,
+            ) = carry  # tr -> total reward; rk -> randkey
+
+            if normalize_obs:
+                obs = norm_obs(state, obs)
+
+            if action_policy is not None:
                 forward_func = lambda obs: act_func(state, params, obs)
-                action = self.action_policy(forward_func, obs)
+                action = action_policy(rk, forward_func, obs)
             else:
                 action = act_func(state, params, obs)
             next_obs, next_env_state, reward, done, _ = self.step(
@@ -80,7 +182,7 @@
             )
             next_rng, _ = jax.random.split(rng)
 
-            if self.record_episode:
+            if record_episode:
                 epis["obs"] = epis["obs"].at[count].set(obs)
                 epis["action"] = epis["action"].at[count].set(action)
                 epis["reward"] = epis["reward"].at[count].set(reward)
@@ -93,24 +195,23 @@
                 tr + reward,
                 count + 1,
                 epis,
+                jax.random.split(rk)[0],
             )
 
-        _, _, _, _, total_reward, _, episode = jax.lax.while_loop(
+        _, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
             cond_func,
             body_func,
-            (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode),
+            (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode, randkey),
         )
 
-        if self.record_episode:
+        if record_episode:
             return total_reward, episode
         else:
             return total_reward
 
-    # @partial(jax.jit, static_argnums=(0,))
     def step(self, randkey, env_state, action):
         return self.env_step(randkey, env_state, action)
 
-    # @partial(jax.jit, static_argnums=(0,))
     def reset(self, randkey):
         return self.env_reset(randkey)
 
@@ -130,3 +231,7 @@ class RLEnv(BaseProblem):
 
     def show(self, state, randkey, act_func, params, *args, **kwargs):
         raise NotImplementedError
+
+
+def norm_obs(state, obs):
+    return (obs - state.problem_obs_mean) / (state.problem_obs_std + 1e-6)
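
Note on the new action_policy contract: as the docstring above states, the
hook now receives (randkey, forward_func, obs) rather than (forward_func,
obs), with a fresh key threaded through the while-loop carry on every step.
A minimal sketch of a conforming policy follows; the Gaussian noise scale
and the [-1, 1] clipping are illustrative assumptions, not part of this
patch:

    import jax
    import jax.numpy as jnp

    def noisy_action_policy(randkey, forward_func, obs):
        # Deterministic action from the evolved network.
        action = forward_func(obs)
        # Per-step exploration noise drawn from the key that body_func
        # passes in; the 0.05 scale is an arbitrary choice here.
        noise = 0.05 * jax.random.normal(randkey, jnp.shape(action))
        return jnp.clip(action + noise, -1.0, 1.0)

Such a policy would be passed as action_policy=noisy_action_policy when
constructing the problem, alongside the arguments shown in half_cheetah.py.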
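
For reference, the statistics computed in setup() reduce to plain NumPy
arithmetic once the sampled episodes are on the host. A small worked
example with made-up numbers, mirroring the NaN-padding convention of
evaluate_once (steps after done remain jnp.nan and must be dropped before
taking the mean and std):

    import numpy as np

    # Two real observation rows plus one NaN-padded row, as produced
    # when an episode terminates before max_step.
    obs = np.array([
        [0.0, 1.0],
        [2.0, 3.0],
        [np.nan, np.nan],  # padded step after done
    ])

    valid = np.all(~np.isnan(obs), axis=1)  # [True, True, False]
    obs = obs[valid]

    obs_mean = obs.mean(axis=0)  # [1.0, 2.0]
    obs_std = obs.std(axis=0)    # [1.0, 1.0]

    # Same transform as norm_obs(); the 1e-6 guards against zero std.
    normed = (obs - obs_mean) / (obs_std + 1e-6)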