{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "overview",
   "metadata": {},
   "source": [
    "# Evaluating RL problems with a random policy\n",
    "\n",
    "Runs `problem.evaluate` with a random policy on a Brax environment (`ant`) and a Gymnax environment (`CartPole-v1`).\n",
    "Each environment is evaluated once and as a `jax.vmap` batch, with and without `record_episode=True`; the jitted evaluators are then called 20 times each to compare wall time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:53:04.429593300Z",
     "start_time": "2024-05-30T08:53:02.326728600Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax, jax.numpy as jnp\n",
    "from tensorneat.utils import State\n",
    "from problem.rl_env import BraxEnv, GymNaxEnv\n",
    "\n",
    "\n",
    "def random_policy(state: State, obs, randkey):\n",
    "    \"\"\"Return a uniform random action of shape (8,) in [-1, 1); `state` and `obs` are ignored.\"\"\"\n",
    "    return jax.random.uniform(randkey, (8,)) * 2 - 1\n",
    "\n",
    "\n",
    "def random_scalar_policy(state: State, obs, randkey):\n",
    "    \"\"\"Return a uniform random scalar action in [0, 1); `state` and `obs` are ignored.\"\"\"\n",
    "    return jax.random.uniform(randkey, ())\n",
    "\n",
    "\n",
    "def make_evaluator(problem, policy):\n",
    "    \"\"\"Build `evaluate(state, env_key, policy_key)` for one problem and one policy.\n",
    "\n",
    "    The problem and the policy are captured as arguments here, so the returned\n",
    "    function keeps using them even if notebook-level names are rebound later.\n",
    "    \"\"\"\n",
    "\n",
    "    def evaluate(state, env_key, policy_key):\n",
    "        return problem.evaluate(state, env_key, policy, policy_key)\n",
    "\n",
    "    return evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "brax-section",
   "metadata": {},
   "source": [
    "## Brax: `ant`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e62882e782d7e54e",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:53:18.928839600Z",
     "start_time": "2024-05-30T08:53:04.435561800Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Array(24.975231, dtype=float32)"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# single evaluation without recording episode\n",
    "randkey = jax.random.key(0)\n",
    "env_key, policy_key = jax.random.split(randkey)\n",
    "brax_env = BraxEnv(env_name=\"ant\", max_step=100)\n",
    "state = brax_env.setup()\n",
    "evaluate_using_random_policy_without_record = make_evaluator(brax_env, random_policy)\n",
    "score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n",
    "score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d01997be61038ea2",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:53:29.458960600Z",
     "start_time": "2024-05-30T08:53:18.928839600Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Array([ -3.274895 , -6.016205 , -6.9032974, 9.187286 ,\n -120.19688 , 12.389805 , -4.6393256, -50.27197 ,\n 9.650737 , -73.77956 ], dtype=float32)"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batch evaluation without recording episode\n",
    "batch = 10\n",
    "env_keys = jax.random.split(env_key, batch)\n",
    "policy_keys = jax.random.split(policy_key, batch)\n",
    "\n",
    "# the state is shared (in_axes None); the keys are mapped over their leading axis\n",
    "scores = jax.jit(\n",
    "    jax.vmap(\n",
    "        evaluate_using_random_policy_without_record,\n",
    "        in_axes=(None, 0, 0)\n",
    "    ))(\n",
    "        state, env_keys, policy_keys\n",
    "    )\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ac6f72e21dd12ee8",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:53:40.372461Z",
     "start_time": "2024-05-30T08:53:29.455962200Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(Array(18.354952, dtype=float32), (100, 27), (100, 8), (100,))"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# single evaluation with recording episode\n",
    "randkey = jax.random.key(0)\n",
    "env_key, policy_key = jax.random.split(randkey)\n",
    "brax_env_recording = BraxEnv(env_name=\"ant\", max_step=100, record_episode=True)\n",
    "evaluate_using_random_policy_with_record = make_evaluator(brax_env_recording, random_policy)\n",
    "score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n",
    "score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c55341b054ee2e8",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:53:51.261470500Z",
     "start_time": "2024-05-30T08:53:40.368462Z"
    }
   },
   "outputs": [],
   "source": [
    "# batch evaluation with recording episode\n",
    "batch = 10\n",
    "env_keys = jax.random.split(env_key, batch)\n",
    "policy_keys = jax.random.split(policy_key, batch)\n",
    "\n",
    "scores, episodes = jax.jit(\n",
    "    jax.vmap(\n",
    "        evaluate_using_random_policy_with_record,\n",
    "        in_axes=(None, 0, 0)\n",
    "    ))(\n",
    "        state, env_keys, policy_keys\n",
    "    )\n",
    "# episode shapes recorded in an earlier run: (10, 100, 27), (10, 100, 8), (10, 100)\n",
    "scores, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "274ca4fd0d0b8663",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:53:55.402886Z",
     "start_time": "2024-05-30T08:53:51.255470600Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Array(18.354952, dtype=float32)"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# jit each evaluator under its own name and call it once before the timing loops below\n",
    "jitted_evaluate_with_record = jax.jit(evaluate_using_random_policy_with_record)\n",
    "jitted_evaluate_without_record = jax.jit(evaluate_using_random_policy_without_record)\n",
    "jitted_evaluate_with_record(state, env_key, policy_key)\n",
    "jitted_evaluate_without_record(state, env_key, policy_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fdb34361d19cb78d",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:54:42.782425800Z",
     "start_time": "2024-05-30T08:53:55.397887700Z"
    }
   },
   "outputs": [],
   "source": [
    "for _ in range(20):\n",
    "    # block so the loop waits for each result instead of only dispatching the work\n",
    "    jax.block_until_ready(jitted_evaluate_with_record(state, env_key, policy_key))\n",
    "# 47s384ms (wall time recorded before block_until_ready was added)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9afdf6923051c9f1",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:55:31.344699500Z",
     "start_time": "2024-05-30T08:54:42.785428500Z"
    }
   },
   "outputs": [],
   "source": [
    "for _ in range(20):\n",
    "    # block so the loop waits for each result instead of only dispatching the work\n",
    "    jax.block_until_ready(jitted_evaluate_without_record(state, env_key, policy_key))\n",
    "# 48s559ms (wall time recorded before block_until_ready was added)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "gymnax-section",
   "metadata": {},
   "source": [
    "## Gymnax: `CartPole-v1`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1de25fb23f519284",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:58:46.652406400Z",
     "start_time": "2024-05-30T08:58:45.606288800Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Array(9., dtype=float32, weak_type=True)"
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# single evaluation without recording episode\n",
    "randkey = jax.random.key(0)\n",
    "env_key, policy_key = jax.random.split(randkey)\n",
    "gymnax_env = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500)\n",
    "state = gymnax_env.setup()\n",
    "evaluate_using_random_policy_without_record = make_evaluator(gymnax_env, random_scalar_policy)\n",
    "score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n",
    "score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "99e745dce6f2872d",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:58:58.323528300Z",
     "start_time": "2024-05-30T08:58:57.272024400Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Array([13., 19., 11., 12., 14., 21., 13., 11., 11., 28.], dtype=float32, weak_type=True)"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batch evaluation without recording episode\n",
    "batch = 10\n",
    "env_keys = jax.random.split(env_key, batch)\n",
    "policy_keys = jax.random.split(policy_key, batch)\n",
    "\n",
    "# the state is shared (in_axes None); the keys are mapped over their leading axis\n",
    "scores = jax.jit(\n",
    "    jax.vmap(\n",
    "        evaluate_using_random_policy_without_record,\n",
    "        in_axes=(None, 0, 0)\n",
    "    ))(\n",
    "        state, env_keys, policy_keys\n",
    "    )\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "257e340ebf24c10d",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:59:18.830495600Z",
     "start_time": "2024-05-30T08:59:17.568087200Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(Array(9., dtype=float32, weak_type=True), (500, 4), (500,), (500,))"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# single evaluation with recording episode\n",
    "randkey = jax.random.key(0)\n",
    "env_key, policy_key = jax.random.split(randkey)\n",
    "gymnax_env_recording = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500, record_episode=True)\n",
    "evaluate_using_random_policy_with_record = make_evaluator(gymnax_env_recording, random_scalar_policy)\n",
    "score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n",
    "score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba8dc68085cd0fc",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:59:34.182539200Z",
     "start_time": "2024-05-30T08:59:32.956339600Z"
    }
   },
   "outputs": [],
   "source": [
    "# batch evaluation with recording episode\n",
    "batch = 10\n",
    "env_keys = jax.random.split(env_key, batch)\n",
    "policy_keys = jax.random.split(policy_key, batch)\n",
    "\n",
    "scores, episodes = jax.jit(\n",
    "    jax.vmap(\n",
    "        evaluate_using_random_policy_with_record,\n",
    "        in_axes=(None, 0, 0)\n",
    "    ))(\n",
    "        state, env_keys, policy_keys\n",
    "    )\n",
    "# episode shapes recorded in an earlier run: (10, 500, 4), (10, 500), (10, 500)\n",
    "scores, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ea01b6663a7ca076",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T08:59:46.472504900Z",
     "start_time": "2024-05-30T08:59:46.419192900Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Array(9., dtype=float32, weak_type=True)"
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# jit each evaluator under its own name and call it once before the timing loops below\n",
    "jitted_evaluate_with_record = jax.jit(evaluate_using_random_policy_with_record)\n",
    "jitted_evaluate_without_record = jax.jit(evaluate_using_random_policy_without_record)\n",
    "jitted_evaluate_with_record(state, env_key, policy_key)\n",
    "jitted_evaluate_without_record(state, env_key, policy_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "989c39c8e20779d0",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T09:00:18.905094200Z",
     "start_time": "2024-05-30T09:00:18.809970900Z"
    }
   },
   "outputs": [],
   "source": [
    "for _ in range(20):\n",
    "    # block so the loop waits for each result instead of only dispatching the work\n",
    "    jax.block_until_ready(jitted_evaluate_with_record(state, env_key, policy_key))\n",
    "# 48ms (wall time recorded before block_until_ready was added)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "bab4782fe674f2d5",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T09:00:19.240415900Z",
     "start_time": "2024-05-30T09:00:19.190416700Z"
    }
   },
   "outputs": [],
   "source": [
    "for _ in range(20):\n",
    "    # block so the loop waits for each result instead of only dispatching the work\n",
    "    jax.block_until_ready(jitted_evaluate_without_record(state, env_key, policy_key))\n",
    "# 43ms (wall time recorded before block_until_ready was added)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}