222 lines
5.3 KiB
Plaintext
222 lines
5.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "initial_id",
|
|
"metadata": {
|
|
"collapsed": true,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-06T11:55:39.434327400Z",
|
|
"start_time": "2024-06-06T11:55:39.361327400Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)"
|
|
},
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import jax, jax.numpy as jnp\n",
|
|
"# 2x2 reference board; the four distinct entries make every rotation and flip easy to tell apart\n",
"a = jnp.array([[1, 2], [3, 4]])\n",
|
|
"def rot_boards(board):\n",
"    \"\"\"Stack the four successive counterclockwise quarter-turn rotations of `board`.\n",
"\n",
"    Args:\n",
"        board: array with at least two dimensions whose first two axes have equal\n",
"            length. `jax.lax.scan` needs the carry to keep one shape across steps,\n",
"            and `jnp.rot90` swaps the lengths of the two rotated axes.\n",
"\n",
"    Returns:\n",
"        Array of shape `(4, *board.shape)`. Entry `k` is `board` after `k + 1`\n",
"        quarter turns, so the last entry equals `board` itself.\n",
"    \"\"\"\n",
"    def rot(carry, _):\n",
"        rotated = jnp.rot90(carry)\n",
"        return rotated, rotated  # (next carry, value recorded for this step)\n",
"\n",
"    # No per-step input is needed, so pass a step count instead of a dummy array.\n",
"    _, boards = jax.lax.scan(rot, board, xs=None, length=4)\n",
"    return boards\n",
|
|
"# Both halves come from the same call, so rows 4-7 of the result repeat rows 0-3.\n",
"# NOTE(review): if the second half was meant to hold the mirrored orientations,\n",
"# it would need rot_boards(jnp.fliplr(a)) instead -- confirm intent.\n",
"a1, a2 = rot_boards(a), rot_boards(a)\n",
"\n",
"a = jnp.concatenate((a1, a2), axis=0)\n",
"a"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)"
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# flatten each of the 8 stacked boards into a single row\n",
"a = jnp.reshape(a, (8, -1))\n",
"a"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-06T11:55:31.121054800Z",
|
|
"start_time": "2024-06-06T11:55:31.075517200Z"
|
|
}
|
|
},
|
|
"id": "639cdecea840351d"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Move names in the order up, right, down, left; the second list is the same\n",
"# order with \"left\" and \"right\" exchanged.\n",
"action = [\n",
"    \"up\", \"right\",\n",
"    \"down\", \"left\",\n",
"]\n",
"lr_flip_action = [\n",
"    \"up\", \"left\",\n",
"    \"down\", \"right\",\n",
"]\n",
|
|
"def action_rot90(li):\n",
"    \"\"\"Return `li` cyclically shifted left by one: the first item moves to the end.\n",
"\n",
"    Built from two slices, so an empty sequence comes back empty instead of\n",
"    raising IndexError, and the input itself is never modified.\n",
"    \"\"\"\n",
"    return li[1:] + li[:1]\n",
|
|
"\n",
|
|
"# NOTE(review): the saved output of the printing loop shows 2x2 boards, so `a` was\n",
"# the original 2x2 board when this ran; in top-to-bottom order `a` is the reshaped\n",
"# (8, 4) array by this point -- confirm which one is intended.\n",
"rl_flip_a = jnp.fliplr(a)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-06T11:22:36.417287600Z",
|
|
"start_time": "2024-06-06T11:22:36.414285500Z"
|
|
}
|
|
},
|
|
"id": "92b75cd0e870a28c"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[[1 2]\n",
|
|
" [3 4]] ['up', 'right', 'down', 'left']\n",
|
|
"[[2 1]\n",
|
|
" [4 3]] ['up', 'left', 'down', 'right']\n",
|
|
"[[2 4]\n",
|
|
" [1 3]] ['right', 'down', 'left', 'up']\n",
|
|
"[[1 3]\n",
|
|
" [2 4]] ['left', 'down', 'right', 'up']\n",
|
|
"[[4 3]\n",
|
|
" [2 1]] ['down', 'left', 'up', 'right']\n",
|
|
"[[3 4]\n",
|
|
" [1 2]] ['down', 'right', 'up', 'left']\n",
|
|
"[[3 1]\n",
|
|
" [4 2]] ['left', 'up', 'right', 'down']\n",
|
|
"[[4 2]\n",
|
|
" [3 1]] ['right', 'up', 'left', 'down']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Step through four quarter turns, printing the board and its mirrored copy next\n",
"# to their move lists; after the fourth turn every variable is back where it started.\n",
"for i in range(4):\n",
"    print(a, action)\n",
"    print(rl_flip_a, lr_flip_action)\n",
"    a, rl_flip_a = jnp.rot90(a), jnp.rot90(rl_flip_a)\n",
"    action, lr_flip_action = action_rot90(action), action_rot90(lr_flip_action)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-06T11:22:36.919614600Z",
|
|
"start_time": "2024-06-06T11:22:36.860704600Z"
|
|
}
|
|
},
|
|
"id": "55e802e0dbcc9c7f"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)"
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# half turn: two quarter turns applied in one call\n",
"jnp.rot90(a, 2)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-06T11:12:48.186719Z",
|
|
"start_time": "2024-06-06T11:12:48.151161900Z"
|
|
}
|
|
},
|
|
"id": "16f8de3cadaa257a"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# mirror left-to-right by reversing the column axis\n",
"jnp.flip(a, axis=1)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-06T11:14:28.668195300Z",
|
|
"start_time": "2024-06-06T11:14:28.631570500Z"
|
|
}
|
|
},
|
|
"id": "1fffa4e597ab5732"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "ca53c916dcff12ae"
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|