{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-vmap-in-jitted-loop",
   "metadata": {},
   "source": [
    "# `vmap` inside a jitted Python loop\n",
    "\n",
    "Two variants of the same loop are defined below: `loop2` calls `vmap(func)` on every iteration, while `loop3` wraps `func` with `vmap` once before the loop starts.\n",
    "\n",
    "The last cell ahead-of-time compiles `loop3` with `jit(...).lower().compile()` and runs the compiled function. Both loops are plain Python `for` loops, so `jit` traces every iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-06-02T08:29:04.093990Z",
     "start_time": "2024-06-02T08:29:04.085992900Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import vmap, jit, numpy as jnp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "39f803029127aaa8",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-02T08:31:13.023886300Z",
     "start_time": "2024-06-02T08:31:13.003026800Z"
    }
   },
   "outputs": [],
   "source": [
    "def func(x, y):\n",
    "    \"\"\"Return ``x + y``; the loops below map it over the leading axis with ``vmap``.\"\"\"\n",
    "    return x + y\n",
    "\n",
    "\n",
    "def loop2(n_steps=1000, n_rows=10000):\n",
    "    \"\"\"Run the loop, calling ``vmap(func)`` afresh on every iteration.\n",
    "\n",
    "    Args:\n",
    "        n_steps: Number of loop iterations. Consumed by ``range``, so it must be\n",
    "            a concrete Python int (leave it at its default, or mark it static,\n",
    "            when wrapping this function with ``jit``).\n",
    "        n_rows: Leading dimension of the ``(n_rows, 1)`` operands; used in a\n",
    "            shape tuple, so the same restriction applies.\n",
    "\n",
    "    Returns:\n",
    "        The sum of the last iteration's ``vmap(func)(x, y)`` output, i.e.\n",
    "        ``n_rows * (2 * n_steps - 1)`` (19990000 for the defaults), or the\n",
    "        Python int ``0`` when ``n_steps`` is 0.\n",
    "    \"\"\"\n",
    "    s = 0\n",
    "    for i in range(n_steps):\n",
    "        x = jnp.full((n_rows, 1), i)\n",
    "        y = jnp.full((n_rows, 1), i + 1)\n",
    "        # NOTE(review): `s` is overwritten, not accumulated, so only the final\n",
    "        # iteration contributes to the result. Confirm this is intended; a\n",
    "        # running total over the default sizes would not fit in int32.\n",
    "        s = vmap(func)(x, y).sum()\n",
    "    return s\n",
    "\n",
    "\n",
    "def loop3(n_steps=1000, n_rows=10000):\n",
    "    \"\"\"Run the same loop as ``loop2`` with ``vmap(func)`` built once up front.\n",
    "\n",
    "    Args:\n",
    "        n_steps: Number of loop iterations. Consumed by ``range``, so it must be\n",
    "            a concrete Python int (leave it at its default, or mark it static,\n",
    "            when wrapping this function with ``jit``).\n",
    "        n_rows: Leading dimension of the ``(n_rows, 1)`` operands; used in a\n",
    "            shape tuple, so the same restriction applies.\n",
    "\n",
    "    Returns:\n",
    "        The sum of the last iteration's ``vmap_func(x, y)`` output, i.e.\n",
    "        ``n_rows * (2 * n_steps - 1)`` (19990000 for the defaults), or the\n",
    "        Python int ``0`` when ``n_steps`` is 0.\n",
    "    \"\"\"\n",
    "    s = 0\n",
    "    vmap_func = vmap(func)  # hoisted: the only difference from loop2\n",
    "    for i in range(n_steps):\n",
    "        x = jnp.full((n_rows, 1), i)\n",
    "        y = jnp.full((n_rows, 1), i + 1)\n",
    "        # NOTE(review): overwritten each iteration, exactly as in loop2.\n",
    "        s = vmap_func(x, y).sum()\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "ab9f83d0a313f51d",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-02T08:31:14.526380100Z",
     "start_time": "2024-06-02T08:31:13.870916800Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Array(19990000, dtype=int32)"
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compile_loop = jit(loop3).lower().compile()\n",
    "compile_loop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}