Files
tensorneat-mend/tensorneat/tmp.ipynb
2024-06-07 17:09:16 +08:00

222 lines
5.3 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "initial_id",
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"# Base board whose symmetries are enumerated below.\n",
"board = jnp.array([\n",
"    [1, 2],\n",
"    [3, 4]\n",
"])\n",
"\n",
"def rot_boards(board):\n",
"    \"\"\"Stack the four counter-clockwise quarter-turn rotations of `board`.\n",
"\n",
"    Returns an array with a new leading axis of length 4 whose entry k is\n",
"    jnp.rot90(board, k + 1); the last entry is the unrotated board.\n",
"    `board` must be square: lax.scan requires the carry to keep the same\n",
"    shape on every step.\n",
"    \"\"\"\n",
"    def rot(carry, _):\n",
"        carry = jnp.rot90(carry)\n",
"        return carry, carry  # (next carry, per-step output to stack)\n",
"\n",
"    # No per-step input is needed, so scan a fixed number of steps.\n",
"    _, boards = jax.lax.scan(rot, board, None, length=4)\n",
"    return boards\n",
"\n",
"# The 8 symmetries of the square: the 4 rotations of the board followed by\n",
"# the 4 rotations of its left-right mirror image.\n",
"rotated = rot_boards(board)\n",
"mirrored = rot_boards(jnp.fliplr(board))\n",
"\n",
"a = jnp.concatenate([rotated, mirrored], axis=0)\n",
"assert len({tuple(x.ravel().tolist()) for x in a}) == 8, \"expected 8 distinct symmetries\"\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Flatten every board in the stack into one row (leading axis kept).\n",
"a = a.reshape(a.shape[0], -1)\n",
"a"
],
"metadata": {
"collapsed": false
},
"id": "639cdecea840351d"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"action = [\"up\", \"right\", \"down\", \"left\"]\n",
"# Same action names ordered for the left-right mirrored board ('left' and 'right' swapped).\n",
"lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n",
"\n",
"def action_rot90(li):\n",
"    \"\"\"Return a new list: `li` with its first element moved to the end.\n",
"\n",
"    Applied once per jnp.rot90 of the board when stepping through its\n",
"    rotations. An empty list yields an empty list instead of raising.\n",
"    \"\"\"\n",
"    return li[1:] + li[:1]\n",
"\n",
"# Rebuild the base board explicitly: the cells above rebind `a` to the\n",
"# stacked symmetries, so `a = a` would pick those up on a fresh kernel.\n",
"a = jnp.array([\n",
"    [1, 2],\n",
"    [3, 4]\n",
"])\n",
"rl_flip_a = jnp.fliplr(a)  # left-right mirror of the base board"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.417287600Z",
"start_time": "2024-06-06T11:22:36.414285500Z"
}
},
"id": "92b75cd0e870a28c"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1 2]\n",
" [3 4]] ['up', 'right', 'down', 'left']\n",
"[[2 1]\n",
" [4 3]] ['up', 'left', 'down', 'right']\n",
"[[2 4]\n",
" [1 3]] ['right', 'down', 'left', 'up']\n",
"[[1 3]\n",
" [2 4]] ['left', 'down', 'right', 'up']\n",
"[[4 3]\n",
" [2 1]] ['down', 'left', 'up', 'right']\n",
"[[3 4]\n",
" [1 2]] ['down', 'right', 'up', 'left']\n",
"[[3 1]\n",
" [4 2]] ['left', 'up', 'right', 'down']\n",
"[[4 2]\n",
" [3 1]] ['right', 'up', 'left', 'down']\n"
]
}
],
"source": [
"# Step through the four rotations, printing each board beside its action\n",
"# list (plain and mirrored). After four quarter turns every variable is\n",
"# back at its starting value.\n",
"for i in range(4):\n",
"    print(a, action)\n",
"    print(rl_flip_a, lr_flip_action)\n",
"    a, rl_flip_a = jnp.rot90(a), jnp.rot90(rl_flip_a)\n",
"    action, lr_flip_action = action_rot90(action), action_rot90(lr_flip_action)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.919614600Z",
"start_time": "2024-06-06T11:22:36.860704600Z"
}
},
"id": "55e802e0dbcc9c7f"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Rotate 180 degrees (two counter-clockwise quarter turns).\n",
"jnp.rot90(a, 2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:12:48.186719Z",
"start_time": "2024-06-06T11:12:48.151161900Z"
}
},
"id": "16f8de3cadaa257a"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mirror left-right: reverse the column order (same operation as jnp.fliplr).\n",
"jnp.flip(a, axis=1)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:14:28.668195300Z",
"start_time": "2024-06-06T11:14:28.631570500Z"
}
},
"id": "1fffa4e597ab5732"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "ca53c916dcff12ae"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}