refactor folder locations

This commit is contained in:
root
2024-07-10 16:58:58 +08:00
parent 51cb4695af
commit 52d5f046d3
13 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1,458 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-05-30T08:53:04.429593300Z",
"start_time": "2024-05-30T08:53:02.326728600Z"
}
},
"outputs": [],
"source": [
"import jax, jax.numpy as jnp\n",
"from tensorneat.utils import State\n",
"from problem.rl_env import BraxEnv\n",
"\n",
"\n",
"def random_policy(state: State, obs, randkey):\n",
" return jax.random.uniform(randkey, (8,)) * 2 - 1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": "Array(24.975231, dtype=float32)"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# single evaluation without recording episode\n",
"randkey = jax.random.key(0)\n",
"env_key, policy_key = jax.random.split(randkey)\n",
"problem = BraxEnv(env_name=\"ant\", max_step=100)\n",
"state = problem.setup()\n",
"evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
" policy_key)\n",
"score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n",
"score"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:53:18.928839600Z",
"start_time": "2024-05-30T08:53:04.435561800Z"
}
},
"id": "e62882e782d7e54e"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "Array([ -3.274895 , -6.016205 , -6.9032974, 9.187286 ,\n -120.19688 , 12.389805 , -4.6393256, -50.27197 ,\n 9.650737 , -73.77956 ], dtype=float32)"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# batch evaluation without recording episode\n",
"batch = 10\n",
"env_keys = jax.random.split(env_key, batch)\n",
"policy_keys = jax.random.split(policy_key, batch)\n",
"\n",
"score = jax.jit(\n",
" jax.vmap(\n",
" evaluate_using_random_policy_without_record, \n",
" in_axes=(None, 0, 0)\n",
" ))(\n",
" state, env_keys, policy_keys\n",
" )\n",
"score"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:53:29.458960600Z",
"start_time": "2024-05-30T08:53:18.928839600Z"
}
},
"id": "d01997be61038ea2"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "(Array(18.354952, dtype=float32), (100, 27), (100, 8), (100,))"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# single evaluation with recording episode\n",
"randkey = jax.random.key(0)\n",
"env_key, policy_key = jax.random.split(randkey)\n",
"problem = BraxEnv(env_name=\"ant\", max_step=100, record_episode=True)\n",
"evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
" policy_key)\n",
"score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n",
"score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:53:40.372461Z",
"start_time": "2024-05-30T08:53:29.455962200Z"
}
},
"id": "ac6f72e21dd12ee8"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "(Array(18.354952, dtype=float32), (10, 100, 27), (10, 100, 8), (10, 100))"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# batch evaluation without recording episode\n",
"batch = 10\n",
"env_keys = jax.random.split(env_key, batch)\n",
"policy_keys = jax.random.split(policy_key, batch)\n",
"\n",
"scores, episodes = jax.jit(\n",
" jax.vmap(\n",
" evaluate_using_random_policy_with_record, \n",
" in_axes=(None, 0, 0)\n",
" ))(\n",
" state, env_keys, policy_keys\n",
" )\n",
"score, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:53:51.261470500Z",
"start_time": "2024-05-30T08:53:40.368462Z"
}
},
"id": "1c55341b054ee2e8"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "Array(18.354952, dtype=float32)"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)\n",
"evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)\n",
"evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
"evaluate_using_random_policy_without_record(state, env_key, policy_key)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:53:55.402886Z",
"start_time": "2024-05-30T08:53:51.255470600Z"
}
},
"id": "274ca4fd0d0b8663"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"for _ in range(20):\n",
" evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
"# 47s384ms"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:54:42.782425800Z",
"start_time": "2024-05-30T08:53:55.397887700Z"
}
},
"id": "fdb34361d19cb78d"
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"for _ in range(20):\n",
" evaluate_using_random_policy_without_record(state, env_key, policy_key)\n",
"# 48s559ms"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:55:31.344699500Z",
"start_time": "2024-05-30T08:54:42.785428500Z"
}
},
"id": "9afdf6923051c9f1"
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [
{
"data": {
"text/plain": "Array(9., dtype=float32, weak_type=True)"
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# single evaluation without recording episode\n",
"from problem.rl_env import GymNaxEnv\n",
"\n",
"def random_policy(state: State, obs, randkey):\n",
" return jax.random.uniform(randkey, ()) \n",
"\n",
"randkey = jax.random.key(0)\n",
"env_key, policy_key = jax.random.split(randkey)\n",
"problem = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500)\n",
"state = problem.setup()\n",
"evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
" policy_key)\n",
"score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n",
"score"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:58:46.652406400Z",
"start_time": "2024-05-30T08:58:45.606288800Z"
}
},
"id": "1de25fb23f519284"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [
{
"data": {
"text/plain": "Array([13., 19., 11., 12., 14., 21., 13., 11., 11., 28.], dtype=float32, weak_type=True)"
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# batch evaluation without recording episode\n",
"batch = 10\n",
"env_keys = jax.random.split(env_key, batch)\n",
"policy_keys = jax.random.split(policy_key, batch)\n",
"\n",
"score = jax.jit(\n",
" jax.vmap(\n",
" evaluate_using_random_policy_without_record, \n",
" in_axes=(None, 0, 0)\n",
" ))(\n",
" state, env_keys, policy_keys\n",
" )\n",
"score"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:58:58.323528300Z",
"start_time": "2024-05-30T08:58:57.272024400Z"
}
},
"id": "99e745dce6f2872d"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"data": {
"text/plain": "(Array(9., dtype=float32, weak_type=True), (500, 4), (500,), (500,))"
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# single evaluation with recording episode\n",
"randkey = jax.random.key(0)\n",
"env_key, policy_key = jax.random.split(randkey)\n",
"problem = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500, record_episode=True)\n",
"evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n",
" policy_key)\n",
"score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n",
"score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:59:18.830495600Z",
"start_time": "2024-05-30T08:59:17.568087200Z"
}
},
"id": "257e340ebf24c10d"
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"data": {
"text/plain": "(Array(9., dtype=float32, weak_type=True), (10, 500, 4), (10, 500), (10, 500))"
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# batch evaluation without recording episode\n",
"batch = 10\n",
"env_keys = jax.random.split(env_key, batch)\n",
"policy_keys = jax.random.split(policy_key, batch)\n",
"\n",
"scores, episodes = jax.jit(\n",
" jax.vmap(\n",
" evaluate_using_random_policy_with_record, \n",
" in_axes=(None, 0, 0)\n",
" ))(\n",
" state, env_keys, policy_keys\n",
" )\n",
"score, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:59:34.182539200Z",
"start_time": "2024-05-30T08:59:32.956339600Z"
}
},
"id": "9ba8dc68085cd0fc"
},
{
"cell_type": "code",
"execution_count": 16,
"outputs": [
{
"data": {
"text/plain": "Array(9., dtype=float32, weak_type=True)"
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)\n",
"evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)\n",
"evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
"evaluate_using_random_policy_without_record(state, env_key, policy_key)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T08:59:46.472504900Z",
"start_time": "2024-05-30T08:59:46.419192900Z"
}
},
"id": "ea01b6663a7ca076"
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [],
"source": [
"for _ in range(20):\n",
" evaluate_using_random_policy_with_record(state, env_key, policy_key)\n",
"# 48ms"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T09:00:18.905094200Z",
"start_time": "2024-05-30T09:00:18.809970900Z"
}
},
"id": "989c39c8e20779d0"
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"for _ in range(20):\n",
" evaluate_using_random_policy_without_record(state, env_key, policy_key)\n",
"# 43ms"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T09:00:19.240415900Z",
"start_time": "2024-05-30T09:00:19.190416700Z"
}
},
"id": "bab4782fe674f2d5"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}