add obs normalization for rl env
This commit is contained in:
@@ -270,7 +270,7 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
fixed_args_output_funcs.append(f)
|
||||
|
||||
forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs]
|
||||
forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
|
||||
|
||||
return (
|
||||
symbols,
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
import jax
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
|
||||
def sample_policy(randkey, obs):
|
||||
return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
@@ -17,7 +24,7 @@ if __name__ == "__main__":
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
),
|
||||
output_transform=Act.tanh
|
||||
output_transform=Act.tanh,
|
||||
),
|
||||
pop_size=1000,
|
||||
species_size=10,
|
||||
@@ -25,6 +32,10 @@ if __name__ == "__main__":
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name="halfcheetah",
|
||||
max_step=1000,
|
||||
obs_normalization=True,
|
||||
sample_episodes=1000,
|
||||
sample_policy=sample_policy,
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from utils import State
|
||||
from .. import BaseProblem
|
||||
@@ -17,19 +17,90 @@ class RLEnv(BaseProblem):
|
||||
repeat_times=1,
|
||||
record_episode=False,
|
||||
action_policy: Callable = None,
|
||||
obs_normalization: bool = False,
|
||||
sample_policy: Callable = None,
|
||||
sample_episodes: int = 0,
|
||||
):
|
||||
"""
|
||||
action_policy take three args:
|
||||
randkey, forward_func, obs
|
||||
randkey is a random key for jax.random
|
||||
forward_func is a function which receive obs and return action forward_func(obs) - > action
|
||||
obs is the observation of the environment
|
||||
|
||||
sample_policy take two args:
|
||||
randkey, obs -> action
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.max_step = max_step
|
||||
self.record_episode = record_episode
|
||||
self.repeat_times = repeat_times
|
||||
self.action_policy = action_policy
|
||||
|
||||
if obs_normalization:
|
||||
assert sample_policy is not None, "sample_policy must be provided"
|
||||
assert sample_episodes > 0, "sample_size must be greater than 0"
|
||||
self.sample_policy = sample_policy
|
||||
self.sample_episodes = sample_episodes
|
||||
self.obs_normalization = obs_normalization
|
||||
|
||||
def setup(self, state=State()):
|
||||
if self.obs_normalization:
|
||||
print("Sampling episodes for normalization")
|
||||
keys = jax.random.split(state.randkey, self.sample_episodes)
|
||||
dummy_act_func = (
|
||||
lambda s, p, o: o
|
||||
) # receive state, params, obs and return the original obs
|
||||
dummy_sample_func = lambda rk, act_func, obs: self.sample_policy(
|
||||
rk, obs
|
||||
) # ignore act_func
|
||||
|
||||
def sample(rk):
|
||||
return self.evaluate_once(
|
||||
state, rk, dummy_act_func, None, dummy_sample_func, True
|
||||
)
|
||||
|
||||
rewards, episodes = jax.jit(jax.vmap(sample))(keys)
|
||||
|
||||
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
|
||||
obs = obs.reshape(
|
||||
-1, *self.input_shape
|
||||
) # shape: (sample_episodes * max_step, *input_shape)
|
||||
|
||||
obs_axis = tuple(range(obs.ndim))
|
||||
valid_data_flag = np.all(~jnp.isnan(obs), axis=obs_axis[1:])
|
||||
obs = obs[valid_data_flag]
|
||||
|
||||
obs_mean = np.mean(obs, axis=0)
|
||||
obs_std = np.std(obs, axis=0)
|
||||
|
||||
state = state.register(
|
||||
problem_obs_mean=obs_mean,
|
||||
problem_obs_std=obs_std,
|
||||
)
|
||||
|
||||
print("Sampling episodes for normalization finished.")
|
||||
print("valid data count: ", obs.shape[0])
|
||||
print("obs_mean: ", obs_mean)
|
||||
print("obs_std: ", obs_std)
|
||||
return state
|
||||
|
||||
def evaluate(self, state: State, randkey, act_func: Callable, params):
|
||||
keys = jax.random.split(randkey, self.repeat_times)
|
||||
if self.record_episode:
|
||||
rewards, episodes = jax.vmap(
|
||||
self.evaluate_once, in_axes=(None, 0, None, None)
|
||||
)(state, keys, act_func, params)
|
||||
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
||||
)(
|
||||
state,
|
||||
keys,
|
||||
act_func,
|
||||
params,
|
||||
self.action_policy,
|
||||
True,
|
||||
self.obs_normalization,
|
||||
)
|
||||
|
||||
episodes["obs"] = episodes["obs"].reshape(
|
||||
self.max_step * self.repeat_times, *self.input_shape
|
||||
)
|
||||
@@ -43,16 +114,34 @@ class RLEnv(BaseProblem):
|
||||
return rewards.mean(), episodes
|
||||
|
||||
else:
|
||||
rewards = jax.vmap(self.evaluate_once, in_axes=(None, 0, None, None))(
|
||||
state, keys, act_func, params
|
||||
rewards = jax.vmap(
|
||||
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
||||
)(
|
||||
state,
|
||||
keys,
|
||||
act_func,
|
||||
params,
|
||||
self.action_policy,
|
||||
False,
|
||||
self.obs_normalization,
|
||||
)
|
||||
|
||||
return rewards.mean()
|
||||
|
||||
def evaluate_once(self, state, randkey, act_func, params):
|
||||
def evaluate_once(
|
||||
self,
|
||||
state,
|
||||
randkey,
|
||||
act_func,
|
||||
params,
|
||||
action_policy,
|
||||
record_episode,
|
||||
normalize_obs=False,
|
||||
):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
if self.record_episode:
|
||||
if record_episode:
|
||||
obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan)
|
||||
action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan)
|
||||
reward_array = jnp.full((self.max_step,), jnp.nan)
|
||||
@@ -65,14 +154,27 @@ class RLEnv(BaseProblem):
|
||||
episode = None
|
||||
|
||||
def cond_func(carry):
|
||||
_, _, _, done, _, count, _ = carry
|
||||
_, _, _, done, _, count, _, rk = carry
|
||||
return ~done & (count < self.max_step)
|
||||
|
||||
def body_func(carry):
|
||||
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
|
||||
if self.action_policy is not None:
|
||||
(
|
||||
obs,
|
||||
env_state,
|
||||
rng,
|
||||
done,
|
||||
tr,
|
||||
count,
|
||||
epis,
|
||||
rk,
|
||||
) = carry # tr -> total reward; rk -> randkey
|
||||
|
||||
if normalize_obs:
|
||||
obs = norm_obs(state, obs)
|
||||
|
||||
if action_policy is not None:
|
||||
forward_func = lambda obs: act_func(state, params, obs)
|
||||
action = self.action_policy(forward_func, obs)
|
||||
action = action_policy(rk, forward_func, obs)
|
||||
else:
|
||||
action = act_func(state, params, obs)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(
|
||||
@@ -80,7 +182,7 @@ class RLEnv(BaseProblem):
|
||||
)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
|
||||
if self.record_episode:
|
||||
if record_episode:
|
||||
epis["obs"] = epis["obs"].at[count].set(obs)
|
||||
epis["action"] = epis["action"].at[count].set(action)
|
||||
epis["reward"] = epis["reward"].at[count].set(reward)
|
||||
@@ -93,24 +195,23 @@ class RLEnv(BaseProblem):
|
||||
tr + reward,
|
||||
count + 1,
|
||||
epis,
|
||||
jax.random.split(rk)[0],
|
||||
)
|
||||
|
||||
_, _, _, _, total_reward, _, episode = jax.lax.while_loop(
|
||||
_, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0, 0, episode),
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0, 0, episode, randkey),
|
||||
)
|
||||
|
||||
if self.record_episode:
|
||||
if record_episode:
|
||||
return total_reward, episode
|
||||
else:
|
||||
return total_reward
|
||||
|
||||
# @partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
# @partial(jax.jit, static_argnums=(0,))
|
||||
def reset(self, randkey):
|
||||
return self.env_reset(randkey)
|
||||
|
||||
@@ -130,3 +231,7 @@ class RLEnv(BaseProblem):
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def norm_obs(state, obs):
|
||||
return (obs - state.problem_obs_mean) / (state.problem_obs_std + 1e-6)
|
||||
|
||||
Reference in New Issue
Block a user