{ "cells": [ { "cell_type": "code", "execution_count": 22, "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2024-06-06T11:55:39.434327400Z", "start_time": "2024-06-06T11:55:39.361327400Z" } }, "outputs": [ { "data": { "text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)" }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jax, jax.numpy as jnp\n", "a = jnp.array([\n", " [1, 2],\n", " [3, 4]\n", "])\n", "def rot_boards(board):\n", " def rot(a, _):\n", " a = jnp.rot90(a)\n", " return a, a # carry, y\n", " \n", " _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))\n", " return boards\n", "a1 = rot_boards(a)\n", "a2 = rot_boards(a)\n", "\n", "a = jnp.concatenate([a1, a2], axis=0)\n", "a" ] }, { "cell_type": "code", "execution_count": 21, "outputs": [ { "data": { "text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)" }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = a.reshape(8, -1)\n", "a" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T11:55:31.121054800Z", "start_time": "2024-06-06T11:55:31.075517200Z" } }, "id": "639cdecea840351d" }, { "cell_type": "code", "execution_count": 13, "outputs": [], "source": [ "action = [\"up\", \"right\", \"down\", \"left\"]\n", "lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n", "def action_rot90(li):\n", " first = li[0]\n", " return li[1:] + [first]\n", "\n", "a = a\n", "rl_flip_a = jnp.fliplr(a)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T11:22:36.417287600Z", "start_time": "2024-06-06T11:22:36.414285500Z" } }, "id": "92b75cd0e870a28c" }, { "cell_type": "code", "execution_count": 14, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1 2]\n", " [3 4]] ['up', 'right', 'down', 'left']\n", "[[2 1]\n", " [4 3]] ['up', 'left', 'down', 'right']\n", "[[2 4]\n", " [1 3]] ['right', 'down', 'left', 'up']\n", "[[1 3]\n", " [2 4]] ['left', 'down', 'right', 'up']\n", "[[4 3]\n", " [2 1]] ['down', 'left', 'up', 'right']\n", "[[3 4]\n", " [1 2]] ['down', 'right', 'up', 'left']\n", "[[3 1]\n", " [4 2]] ['left', 'up', 'right', 'down']\n", "[[4 2]\n", " [3 1]] ['right', 'up', 'left', 'down']\n" ] } ], "source": [ "for i in range(4):\n", " print(a, action)\n", " print(rl_flip_a, lr_flip_action)\n", " a = jnp.rot90(a)\n", " rl_flip_a = jnp.rot90(rl_flip_a)\n", " action = action_rot90(action)\n", " lr_flip_action = action_rot90(lr_flip_action)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T11:22:36.919614600Z", "start_time": "2024-06-06T11:22:36.860704600Z" } }, "id": "55e802e0dbcc9c7f" }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "data": { "text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jnp.rot90(a, k=2)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T11:12:48.186719Z", "start_time": "2024-06-06T11:12:48.151161900Z" } }, "id": "16f8de3cadaa257a" }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "data": { "text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)" }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# flip left-right\n", "jnp.fliplr(a)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T11:14:28.668195300Z", "start_time": "2024-06-06T11:14:28.631570500Z" } }, "id": "1fffa4e597ab5732" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false }, "id": "ca53c916dcff12ae" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }