add args record_episode in rl tasks, with related test "test_record_episode.ipynb";
add args return_data in func_fit tasks.
This commit is contained in:
@@ -20,14 +20,14 @@ if __name__ == "__main__":
|
|||||||
output_transform=Act.sigmoid, # the activation function for output node
|
output_transform=Act.sigmoid, # the activation function for output node
|
||||||
mutation=DefaultMutation(
|
mutation=DefaultMutation(
|
||||||
node_add=0.05,
|
node_add=0.05,
|
||||||
conn_add=0.2,
|
conn_add=0.05,
|
||||||
node_delete=0,
|
node_delete=0,
|
||||||
conn_delete=0,
|
conn_delete=0,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
pop_size=100,
|
||||||
species_size=10,
|
species_size=20,
|
||||||
compatibility_threshold=3.5,
|
compatibility_threshold=2,
|
||||||
survival_threshold=0.01, # magic
|
survival_threshold=0.01, # magic
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ from .. import BaseProblem
|
|||||||
class FuncFit(BaseProblem):
|
class FuncFit(BaseProblem):
|
||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
def __init__(self, error_method: str = "mse"):
|
def __init__(self, error_method: str = "mse", return_data: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert error_method in {"mse", "rmse", "mae", "mape"}
|
assert error_method in {"mse", "rmse", "mae", "mape"}
|
||||||
self.error_method = error_method
|
self.error_method = error_method
|
||||||
|
self.return_data = return_data
|
||||||
|
|
||||||
def setup(self, state: State = State()):
|
def setup(self, state: State = State()):
|
||||||
return state
|
return state
|
||||||
@@ -38,6 +39,9 @@ class FuncFit(BaseProblem):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if self.return_data:
|
||||||
|
return -loss, self.inputs
|
||||||
|
else:
|
||||||
return -loss
|
return -loss
|
||||||
|
|
||||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from .func_fit import FuncFit
|
|||||||
|
|
||||||
|
|
||||||
class XOR(FuncFit):
|
class XOR(FuncFit):
|
||||||
def __init__(self, error_method: str = "mse"):
|
|
||||||
super().__init__(error_method)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inputs(self):
|
def inputs(self):
|
||||||
|
|||||||
@@ -4,9 +4,6 @@ from .func_fit import FuncFit
|
|||||||
|
|
||||||
|
|
||||||
class XOR3d(FuncFit):
|
class XOR3d(FuncFit):
|
||||||
def __init__(self, error_method: str = "mse"):
|
|
||||||
super().__init__(error_method)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inputs(self):
|
def inputs(self):
|
||||||
return np.array(
|
return np.array(
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from .rl_jit import RLEnv
|
|||||||
|
|
||||||
|
|
||||||
class BraxEnv(RLEnv):
|
class BraxEnv(RLEnv):
|
||||||
def __init__(self, max_step=1000, env_name: str = "ant", backend: str = "generalized"):
|
def __init__(self, max_step=1000, record_episode=False, env_name: str = "ant", backend: str = "generalized"):
|
||||||
super().__init__(max_step)
|
super().__init__(max_step, record_episode)
|
||||||
self.env = envs.create(env_name=env_name, backend=backend)
|
self.env = envs.create(env_name=env_name, backend=backend)
|
||||||
|
|
||||||
def env_step(self, randkey, env_state, action):
|
def env_step(self, randkey, env_state, action):
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ from .rl_jit import RLEnv
|
|||||||
|
|
||||||
|
|
||||||
class GymNaxEnv(RLEnv):
|
class GymNaxEnv(RLEnv):
|
||||||
def __init__(self, env_name, max_step=1000):
|
def __init__(self, env_name, max_step=1000, record_episode=False):
|
||||||
super().__init__(max_step)
|
super().__init__(max_step, record_episode)
|
||||||
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
|
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
|
||||||
self.env, self.env_params = gymnax.make(env_name)
|
self.env, self.env_params = gymnax.make(env_name)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from .. import BaseProblem
|
from .. import BaseProblem
|
||||||
|
|
||||||
@@ -8,31 +9,63 @@ from .. import BaseProblem
|
|||||||
class RLEnv(BaseProblem):
|
class RLEnv(BaseProblem):
|
||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
def __init__(self, max_step=1000):
|
def __init__(self, max_step=1000, record_episode=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_step = max_step
|
self.max_step = max_step
|
||||||
|
self.record_episode = record_episode
|
||||||
|
|
||||||
def evaluate(self, state, randkey, act_func, params):
|
def evaluate(self, state, randkey, act_func, params):
|
||||||
rng_reset, rng_episode = jax.random.split(randkey)
|
rng_reset, rng_episode = jax.random.split(randkey)
|
||||||
init_obs, init_env_state = self.reset(rng_reset)
|
init_obs, init_env_state = self.reset(rng_reset)
|
||||||
|
|
||||||
|
if self.record_episode:
|
||||||
|
obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan)
|
||||||
|
action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan)
|
||||||
|
reward_array = jnp.full((self.max_step,), jnp.nan)
|
||||||
|
episode = {
|
||||||
|
"obs": obs_array,
|
||||||
|
"action": action_array,
|
||||||
|
"reward": reward_array,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
episode = None
|
||||||
|
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
_, _, _, done, _, count = carry
|
_, _, _, done, _, count, _ = carry
|
||||||
return ~done & (count < self.max_step)
|
return ~done & (count < self.max_step)
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
obs, env_state, rng, done, tr, count = carry # tr -> total reward
|
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
|
||||||
action = act_func(state, obs, params)
|
action = act_func(state, obs, params)
|
||||||
next_obs, next_env_state, reward, done, _ = self.step(
|
next_obs, next_env_state, reward, done, _ = self.step(
|
||||||
rng, env_state, action
|
rng, env_state, action
|
||||||
)
|
)
|
||||||
next_rng, _ = jax.random.split(rng)
|
next_rng, _ = jax.random.split(rng)
|
||||||
return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
|
|
||||||
|
|
||||||
_, _, _, _, total_reward, _ = jax.lax.while_loop(
|
if self.record_episode:
|
||||||
cond_func, body_func, (init_obs, init_env_state, rng_episode, False, 0.0, 0)
|
epis["obs"] = epis["obs"].at[count].set(obs)
|
||||||
|
epis["action"] = epis["action"].at[count].set(action)
|
||||||
|
epis["reward"] = epis["reward"].at[count].set(reward)
|
||||||
|
|
||||||
|
return (
|
||||||
|
next_obs,
|
||||||
|
next_env_state,
|
||||||
|
next_rng,
|
||||||
|
done,
|
||||||
|
tr + reward,
|
||||||
|
count + 1,
|
||||||
|
epis,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_, _, _, _, total_reward, _, episode = jax.lax.while_loop(
|
||||||
|
cond_func,
|
||||||
|
body_func,
|
||||||
|
(init_obs, init_env_state, rng_episode, False, 0.0, 0, episode),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.record_episode:
|
||||||
|
return total_reward, episode
|
||||||
|
else:
|
||||||
return total_reward
|
return total_reward
|
||||||
|
|
||||||
# @partial(jax.jit, static_argnums=(0,))
|
# @partial(jax.jit, static_argnums=(0,))
|
||||||
|
|||||||
458
tensorneat/test/test_record_episode.ipynb
Normal file
458
tensorneat/test/test_record_episode.ipynb
Normal file
@@ -0,0 +1,458 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "initial_id",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:53:04.429593300Z",
|
||||||
|
"start_time": "2024-05-30T08:53:02.326728600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import jax, jax.numpy as jnp\n",
|
||||||
|
"from utils import State\n",
|
||||||
|
"from problem.rl_env import BraxEnv\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def random_policy(state: State, obs, randkey):\n",
|
||||||
|
" return jax.random.uniform(randkey, (8,)) * 2 - 1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array(24.975231, dtype=float32)"
|
||||||
|
},
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# single evaluation without recording episode\n",
|
||||||
|
"randkey = jax.random.key(0)\n",
|
||||||
|
"env_key, policy_key = jax.random.split(randkey)\n",
|
||||||
|
"problem = BraxEnv(env_name=\"ant\", max_step=100)\n",
|
||||||
|
"state = problem.setup()\n",
|
||||||
|
"evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
|
||||||
|
" policy_key)\n",
|
||||||
|
"score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n",
|
||||||
|
"score"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:53:18.928839600Z",
|
||||||
|
"start_time": "2024-05-30T08:53:04.435561800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "e62882e782d7e54e"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array([ -3.274895 , -6.016205 , -6.9032974, 9.187286 ,\n -120.19688 , 12.389805 , -4.6393256, -50.27197 ,\n 9.650737 , -73.77956 ], dtype=float32)"
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# batch evaluation without recording episode\n",
|
||||||
|
"batch = 10\n",
|
||||||
|
"env_keys = jax.random.split(env_key, batch)\n",
|
||||||
|
"policy_keys = jax.random.split(policy_key, batch)\n",
|
||||||
|
"\n",
|
||||||
|
"score = jax.jit(\n",
|
||||||
|
" jax.vmap(\n",
|
||||||
|
" evaluate_using_random_policy_without_record, \n",
|
||||||
|
" in_axes=(None, 0, 0)\n",
|
||||||
|
" ))(\n",
|
||||||
|
" state, env_keys, policy_keys\n",
|
||||||
|
" )\n",
|
||||||
|
"score"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:53:29.458960600Z",
|
||||||
|
"start_time": "2024-05-30T08:53:18.928839600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "d01997be61038ea2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "(Array(18.354952, dtype=float32), (100, 27), (100, 8), (100,))"
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# single evaluation with recording episode\n",
|
||||||
|
"randkey = jax.random.key(0)\n",
|
||||||
|
"env_key, policy_key = jax.random.split(randkey)\n",
|
||||||
|
"problem = BraxEnv(env_name=\"ant\", max_step=100, record_episode=True)\n",
|
||||||
|
"evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
|
||||||
|
" policy_key)\n",
|
||||||
|
"score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n",
|
||||||
|
"score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:53:40.372461Z",
|
||||||
|
"start_time": "2024-05-30T08:53:29.455962200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "ac6f72e21dd12ee8"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "(Array(18.354952, dtype=float32), (10, 100, 27), (10, 100, 8), (10, 100))"
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# batch evaluation without recording episode\n",
|
||||||
|
"batch = 10\n",
|
||||||
|
"env_keys = jax.random.split(env_key, batch)\n",
|
||||||
|
"policy_keys = jax.random.split(policy_key, batch)\n",
|
||||||
|
"\n",
|
||||||
|
"scores, episodes = jax.jit(\n",
|
||||||
|
" jax.vmap(\n",
|
||||||
|
" evaluate_using_random_policy_with_record, \n",
|
||||||
|
" in_axes=(None, 0, 0)\n",
|
||||||
|
" ))(\n",
|
||||||
|
" state, env_keys, policy_keys\n",
|
||||||
|
" )\n",
|
||||||
|
"score, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:53:51.261470500Z",
|
||||||
|
"start_time": "2024-05-30T08:53:40.368462Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "1c55341b054ee2e8"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array(18.354952, dtype=float32)"
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)\n",
|
||||||
|
"evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)\n",
|
||||||
|
"evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
|
||||||
|
"evaluate_using_random_policy_without_record(state, env_key, policy_key)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:53:55.402886Z",
|
||||||
|
"start_time": "2024-05-30T08:53:51.255470600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "274ca4fd0d0b8663"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for _ in range(20):\n",
|
||||||
|
" evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
|
||||||
|
"# 47s384ms"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:54:42.782425800Z",
|
||||||
|
"start_time": "2024-05-30T08:53:55.397887700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "fdb34361d19cb78d"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for _ in range(20):\n",
|
||||||
|
" evaluate_using_random_policy_without_record(state, env_key, policy_key)\n",
|
||||||
|
"# 48s559ms"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:55:31.344699500Z",
|
||||||
|
"start_time": "2024-05-30T08:54:42.785428500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "9afdf6923051c9f1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array(9., dtype=float32, weak_type=True)"
|
||||||
|
},
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# single evaluation without recording episode\n",
|
||||||
|
"from problem.rl_env import GymNaxEnv\n",
|
||||||
|
"\n",
|
||||||
|
"def random_policy(state: State, obs, randkey):\n",
|
||||||
|
" return jax.random.uniform(randkey, ()) \n",
|
||||||
|
"\n",
|
||||||
|
"randkey = jax.random.key(0)\n",
|
||||||
|
"env_key, policy_key = jax.random.split(randkey)\n",
|
||||||
|
"problem = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500)\n",
|
||||||
|
"state = problem.setup()\n",
|
||||||
|
"evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
|
||||||
|
" policy_key)\n",
|
||||||
|
"score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n",
|
||||||
|
"score"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:58:46.652406400Z",
|
||||||
|
"start_time": "2024-05-30T08:58:45.606288800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "1de25fb23f519284"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array([13., 19., 11., 12., 14., 21., 13., 11., 11., 28.], dtype=float32, weak_type=True)"
|
||||||
|
},
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# batch evaluation without recording episode\n",
|
||||||
|
"batch = 10\n",
|
||||||
|
"env_keys = jax.random.split(env_key, batch)\n",
|
||||||
|
"policy_keys = jax.random.split(policy_key, batch)\n",
|
||||||
|
"\n",
|
||||||
|
"score = jax.jit(\n",
|
||||||
|
" jax.vmap(\n",
|
||||||
|
" evaluate_using_random_policy_without_record, \n",
|
||||||
|
" in_axes=(None, 0, 0)\n",
|
||||||
|
" ))(\n",
|
||||||
|
" state, env_keys, policy_keys\n",
|
||||||
|
" )\n",
|
||||||
|
"score"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:58:58.323528300Z",
|
||||||
|
"start_time": "2024-05-30T08:58:57.272024400Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "99e745dce6f2872d"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "(Array(9., dtype=float32, weak_type=True), (500, 4), (500,), (500,))"
|
||||||
|
},
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# single evaluation with recording episode\n",
|
||||||
|
"randkey = jax.random.key(0)\n",
|
||||||
|
"env_key, policy_key = jax.random.split(randkey)\n",
|
||||||
|
"problem = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500, record_episode=True)\n",
|
||||||
|
"evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
|
||||||
|
" policy_key)\n",
|
||||||
|
"score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n",
|
||||||
|
"score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:59:18.830495600Z",
|
||||||
|
"start_time": "2024-05-30T08:59:17.568087200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "257e340ebf24c10d"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "(Array(9., dtype=float32, weak_type=True), (10, 500, 4), (10, 500), (10, 500))"
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# batch evaluation without recording episode\n",
|
||||||
|
"batch = 10\n",
|
||||||
|
"env_keys = jax.random.split(env_key, batch)\n",
|
||||||
|
"policy_keys = jax.random.split(policy_key, batch)\n",
|
||||||
|
"\n",
|
||||||
|
"scores, episodes = jax.jit(\n",
|
||||||
|
" jax.vmap(\n",
|
||||||
|
" evaluate_using_random_policy_with_record, \n",
|
||||||
|
" in_axes=(None, 0, 0)\n",
|
||||||
|
" ))(\n",
|
||||||
|
" state, env_keys, policy_keys\n",
|
||||||
|
" )\n",
|
||||||
|
"score, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:59:34.182539200Z",
|
||||||
|
"start_time": "2024-05-30T08:59:32.956339600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "9ba8dc68085cd0fc"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array(9., dtype=float32, weak_type=True)"
|
||||||
|
},
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)\n",
|
||||||
|
"evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)\n",
|
||||||
|
"evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
|
||||||
|
"evaluate_using_random_policy_without_record(state, env_key, policy_key)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T08:59:46.472504900Z",
|
||||||
|
"start_time": "2024-05-30T08:59:46.419192900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "ea01b6663a7ca076"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for _ in range(20):\n",
|
||||||
|
" evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
|
||||||
|
"# 48ms"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T09:00:18.905094200Z",
|
||||||
|
"start_time": "2024-05-30T09:00:18.809970900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "989c39c8e20779d0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for _ in range(20):\n",
|
||||||
|
" evaluate_using_random_policy_without_record(state, env_key, policy_key)\n",
|
||||||
|
"# 43ms"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-30T09:00:19.240415900Z",
|
||||||
|
"start_time": "2024-05-30T09:00:19.190416700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "bab4782fe674f2d5"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user