Files
tensorneat-mend/tensorneat/problem/rl_env/rl_jit.py
2024-06-14 16:11:50 +08:00

238 lines
7.3 KiB
Python

from typing import Callable
import jax
import jax.numpy as jnp
import numpy as np
from utils import State
from .. import BaseProblem
class RLEnv(BaseProblem):
    """Base class for reinforcement-learning problems evaluated by rollout.

    Subclasses provide the environment through ``env_step``, ``env_reset``,
    ``input_shape`` and ``output_shape`` (and optionally ``show``); this class
    implements the episode loop, repeated evaluation, optional episode
    recording and optional observation normalization.
    """

    jitable = True

    def __init__(
        self,
        max_step=1000,
        repeat_times=1,
        record_episode=False,
        action_policy: Callable = None,
        obs_normalization: bool = False,
        sample_policy: Callable = None,
        sample_episodes: int = 0,
    ):
        """Configure the rollout.

        Args:
            max_step: upper bound on the number of steps in one episode.
            repeat_times: number of episodes averaged by ``evaluate``.
            record_episode: when true, ``evaluate`` also returns the recorded
                observations, actions and rewards.
            action_policy: optional callable taking three arguments,
                ``(randkey, forward_func, obs) -> action``. ``randkey`` is a
                key for ``jax.random``, ``forward_func`` maps an observation
                to an action (``forward_func(obs) -> action``) and ``obs`` is
                the current observation.
            obs_normalization: when true, ``setup`` samples episodes to
                estimate the observation mean/std and rollouts started by
                ``evaluate`` normalize observations with them.
            sample_policy: callable taking two arguments,
                ``(randkey, obs) -> action``; required when
                ``obs_normalization`` is true.
            sample_episodes: number of episodes sampled by ``setup``; must be
                positive when ``obs_normalization`` is true.
        """
        super().__init__()
        self.max_step = max_step
        self.record_episode = record_episode
        self.repeat_times = repeat_times
        self.action_policy = action_policy

        if obs_normalization:
            assert sample_policy is not None, "sample_policy must be provided"
            assert sample_episodes > 0, "sample_episodes must be greater than 0"
        self.sample_policy = sample_policy
        self.sample_episodes = sample_episodes
        self.obs_normalization = obs_normalization

    def setup(self, state=State()):
        """Estimate observation statistics when normalization is enabled.

        Without ``obs_normalization`` the state is returned unchanged.
        Otherwise ``sample_episodes`` episodes are rolled out with
        ``sample_policy``, the recorded observations are pooled, and their
        per-feature mean and standard deviation are registered on the state
        as ``problem_obs_mean`` and ``problem_obs_std``.
        """
        if not self.obs_normalization:
            return state

        print("Sampling episodes for normalization")
        keys = jax.random.split(state.randkey, self.sample_episodes)

        def dummy_act_func(s, p, o):
            # Receives state, params, obs and returns the original obs.
            return o

        def dummy_sample_func(rk, act_func, obs):
            # The sampling policy does not use the network, so act_func is ignored.
            return self.sample_policy(rk, obs)

        def sample(rk):
            return self.evaluate_once(
                state, rk, dummy_act_func, None, dummy_sample_func, True
            )

        _, episodes = jax.jit(jax.vmap(sample))(keys)

        # shape: (sample_episodes, max_step, *input_shape)
        obs = jax.device_get(episodes["obs"])
        # shape: (sample_episodes * max_step, *input_shape)
        obs = obs.reshape(-1, *self.input_shape)

        # Steps never reached keep their NaN fill value; drop those rows.
        # The data already lives on the host, so stay in NumPy here.
        feature_axes = tuple(range(1, obs.ndim))
        valid_data_flag = np.all(~np.isnan(obs), axis=feature_axes)
        obs = obs[valid_data_flag]

        obs_mean = np.mean(obs, axis=0)
        obs_std = np.std(obs, axis=0)

        state = state.register(
            problem_obs_mean=obs_mean,
            problem_obs_std=obs_std,
        )

        print("Sampling episodes for normalization finished.")
        print("valid data count: ", obs.shape[0])
        print("obs_mean: ", obs_mean)
        print("obs_std: ", obs_std)
        return state

    def evaluate(self, state: State, randkey, act_func: Callable, params):
        """Run ``repeat_times`` episodes and return the mean total reward.

        When ``record_episode`` is set, returns ``(mean_reward, episodes)``
        where ``episodes`` holds ``"obs"``, ``"action"`` and ``"reward"``
        arrays with the repeat and step axes merged into one axis of length
        ``max_step * repeat_times``.
        """
        keys = jax.random.split(randkey, self.repeat_times)
        record = bool(self.record_episode)

        # Only the random key is batched; everything else is shared.
        rollout = jax.vmap(
            self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
        )
        result = rollout(
            state,
            keys,
            act_func,
            params,
            self.action_policy,
            record,
            self.obs_normalization,
        )

        if not record:
            return result.mean()

        rewards, episodes = result
        total_steps = self.max_step * self.repeat_times
        episodes["obs"] = episodes["obs"].reshape(total_steps, *self.input_shape)
        episodes["action"] = episodes["action"].reshape(
            total_steps, *self.output_shape
        )
        episodes["reward"] = episodes["reward"].reshape(
            total_steps,
        )
        return rewards.mean(), episodes

    def evaluate_once(
        self,
        state,
        randkey,
        act_func,
        params,
        action_policy,
        record_episode,
        normalize_obs=False,
    ):
        """Roll out a single episode and return its total reward.

        The loop stops when the environment reports ``done`` or after
        ``max_step`` steps. ``record_episode`` and ``normalize_obs`` are
        branched on in Python, so they must be plain booleans.

        Args:
            state: passed to ``act_func`` and, when normalizing, to
                ``norm_obs``.
            randkey: ``jax.random`` key for this episode.
            act_func: called as ``act_func(state, params, obs)``.
            params: forwarded to ``act_func``.
            action_policy: ``None`` or a callable
                ``(randkey, forward_func, obs) -> action``.
            record_episode: when true, also return the per-step record.
            normalize_obs: when true, observations are normalized with
                ``norm_obs`` before being used (and recorded).

        Returns:
            ``total_reward`` or ``(total_reward, episode)``; unreached steps
            in ``episode`` stay NaN.
        """
        # Independent streams for reset, environment stepping and the action
        # policy, so that no key is both consumed and used as a split parent.
        rng_reset, rng_episode, rng_policy = jax.random.split(randkey, 3)
        init_obs, init_env_state = self.reset(rng_reset)

        if record_episode:
            episode = {
                "obs": jnp.full((self.max_step, *self.input_shape), jnp.nan),
                "action": jnp.full((self.max_step, *self.output_shape), jnp.nan),
                "reward": jnp.full((self.max_step,), jnp.nan),
            }
        else:
            episode = None

        def cond_func(carry):
            _, _, _, done, _, count, _, _ = carry
            return ~done & (count < self.max_step)

        def body_func(carry):
            # tr -> total reward; rk -> key stream for the action policy
            obs, env_state, rng, done, tr, count, epis, rk = carry

            # Derive a fresh sub-key for this step before touching the env.
            next_rng, step_key = jax.random.split(rng)

            if normalize_obs:
                obs = norm_obs(state, obs)

            if action_policy is not None:
                next_rk, policy_key = jax.random.split(rk)

                def forward_func(o):
                    return act_func(state, params, o)

                action = action_policy(policy_key, forward_func, obs)
            else:
                next_rk = rk
                action = act_func(state, params, obs)

            next_obs, next_env_state, reward, done, _ = self.step(
                step_key, env_state, action
            )

            if record_episode:
                epis = {
                    "obs": epis["obs"].at[count].set(obs),
                    "action": epis["action"].at[count].set(action),
                    "reward": epis["reward"].at[count].set(reward),
                }

            return (
                next_obs,
                next_env_state,
                next_rng,
                done,
                tr + reward,
                count + 1,
                epis,
                next_rk,
            )

        _, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
            cond_func,
            body_func,
            (
                init_obs,
                init_env_state,
                rng_episode,
                False,
                0.0,
                0,
                episode,
                rng_policy,
            ),
        )

        if record_episode:
            return total_reward, episode
        return total_reward

    def step(self, randkey, env_state, action):
        """Advance the environment by one step (delegates to ``env_step``)."""
        return self.env_step(randkey, env_state, action)

    def reset(self, randkey):
        """Reset the environment (delegates to ``env_reset``)."""
        return self.env_reset(randkey)

    def env_step(self, randkey, env_state, action):
        """Subclass hook; ``evaluate_once`` unpacks its result as a 5-tuple
        ``(next_obs, next_env_state, reward, done, _)``."""
        raise NotImplementedError

    def env_reset(self, randkey):
        """Subclass hook; ``evaluate_once`` unpacks its result as
        ``(init_obs, init_env_state)``."""
        raise NotImplementedError

    @property
    def input_shape(self):
        """Shape of a single observation; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def output_shape(self):
        """Shape of a single action; must be provided by subclasses."""
        raise NotImplementedError

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError
def norm_obs(state, obs, eps=1e-6):
    """Standardize ``obs`` with the statistics stored on ``state``.

    Args:
        state: object exposing ``problem_obs_mean`` and ``problem_obs_std``.
        obs: observation to normalize.
        eps: added to the standard deviation so that features with zero
            spread do not cause a division by zero. Defaults to ``1e-6``.

    Returns:
        ``(obs - mean) / (std + eps)``.
    """
    return (obs - state.problem_obs_mean) / (state.problem_obs_std + eps)