add sympy support; which can transfer your network into sympy expression;
add visualize in genome; add related tests.
This commit is contained in:
@@ -2,189 +2,110 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "initial_id",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:55:39.434327400Z",
|
||||
"start_time": "2024-06-06T11:55:39.361327400Z"
|
||||
}
|
||||
},
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)"
|
||||
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
|
||||
},
|
||||
"execution_count": 22,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"a = jnp.array([\n",
|
||||
" [1, 2],\n",
|
||||
" [3, 4]\n",
|
||||
"])\n",
|
||||
"def rot_boards(board):\n",
|
||||
" def rot(a, _):\n",
|
||||
" a = jnp.rot90(a)\n",
|
||||
" return a, a # carry, y\n",
|
||||
" \n",
|
||||
" _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))\n",
|
||||
" return boards\n",
|
||||
"a1 = rot_boards(a)\n",
|
||||
"a2 = rot_boards(a)\n",
|
||||
"\n",
|
||||
"a = jnp.concatenate([a1, a2], axis=0)\n",
|
||||
"a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)"
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a = a.reshape(8, -1)\n",
|
||||
"a"
|
||||
"from algorithm.neat import *\n",
|
||||
"from utils import Act, Agg\n",
|
||||
"genome = DefaultGenome(\n",
|
||||
" num_inputs=27,\n",
|
||||
" num_outputs=8,\n",
|
||||
" max_nodes=100,\n",
|
||||
" max_conns=200,\n",
|
||||
" node_gene=DefaultNodeGene(\n",
|
||||
" activation_options=(Act.tanh,),\n",
|
||||
" activation_default=Act.tanh,\n",
|
||||
" ),\n",
|
||||
" output_transform=Act.tanh,\n",
|
||||
")\n",
|
||||
"state = genome.setup()\n",
|
||||
"genome"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:55:31.121054800Z",
|
||||
"start_time": "2024-06-06T11:55:31.075517200Z"
|
||||
"end_time": "2024-06-09T12:08:22.569123400Z",
|
||||
"start_time": "2024-06-09T12:08:19.331863800Z"
|
||||
}
|
||||
},
|
||||
"id": "639cdecea840351d"
|
||||
"id": "b2b214a5454c4814"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"state = state.register(data=jnp.zeros((1, 27)))\n",
|
||||
"# try to save the genome object\n",
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"with open('genome.pkl', 'wb') as f:\n",
|
||||
" genome.__dict__[\"state\"] = state\n",
|
||||
" pickle.dump(genome, f)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-09T12:09:01.943445900Z",
|
||||
"start_time": "2024-06-09T12:09:01.919416Z"
|
||||
}
|
||||
},
|
||||
"id": "28348dfc458e8473"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"action = [\"up\", \"right\", \"down\", \"left\"]\n",
|
||||
"lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n",
|
||||
"def action_rot90(li):\n",
|
||||
" first = li[0]\n",
|
||||
" return li[1:] + [first]\n",
|
||||
"\n",
|
||||
"a = a\n",
|
||||
"rl_flip_a = jnp.fliplr(a)"
|
||||
"# try to load the genome object\n",
|
||||
"with open('genome.pkl', 'rb') as f:\n",
|
||||
" genome = pickle.load(f)\n",
|
||||
" state = genome.state\n",
|
||||
" del genome.__dict__[\"state\"]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:22:36.417287600Z",
|
||||
"start_time": "2024-06-06T11:22:36.414285500Z"
|
||||
"end_time": "2024-06-09T12:10:28.621539400Z",
|
||||
"start_time": "2024-06-09T12:10:28.612540100Z"
|
||||
}
|
||||
},
|
||||
"id": "92b75cd0e870a28c"
|
||||
"id": "c91be9fe3d2b5d5d"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[1 2]\n",
|
||||
" [3 4]] ['up', 'right', 'down', 'left']\n",
|
||||
"[[2 1]\n",
|
||||
" [4 3]] ['up', 'left', 'down', 'right']\n",
|
||||
"[[2 4]\n",
|
||||
" [1 3]] ['right', 'down', 'left', 'up']\n",
|
||||
"[[1 3]\n",
|
||||
" [2 4]] ['left', 'down', 'right', 'up']\n",
|
||||
"[[4 3]\n",
|
||||
" [2 1]] ['down', 'left', 'up', 'right']\n",
|
||||
"[[3 4]\n",
|
||||
" [1 2]] ['down', 'right', 'up', 'left']\n",
|
||||
"[[3 1]\n",
|
||||
" [4 2]] ['left', 'up', 'right', 'down']\n",
|
||||
"[[4 2]\n",
|
||||
" [3 1]] ['right', 'up', 'left', 'down']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for i in range(4):\n",
|
||||
" print(a, action)\n",
|
||||
" print(rl_flip_a, lr_flip_action)\n",
|
||||
" a = jnp.rot90(a)\n",
|
||||
" rl_flip_a = jnp.rot90(rl_flip_a)\n",
|
||||
" action = action_rot90(action)\n",
|
||||
" lr_flip_action = action_rot90(lr_flip_action)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:22:36.919614600Z",
|
||||
"start_time": "2024-06-06T11:22:36.860704600Z"
|
||||
}
|
||||
},
|
||||
"id": "55e802e0dbcc9c7f"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 15,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)"
|
||||
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jnp.rot90(a, k=2)"
|
||||
"state"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:12:48.186719Z",
|
||||
"start_time": "2024-06-06T11:12:48.151161900Z"
|
||||
"end_time": "2024-06-09T12:10:34.103124Z",
|
||||
"start_time": "2024-06-09T12:10:34.096124300Z"
|
||||
}
|
||||
},
|
||||
"id": "16f8de3cadaa257a"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# flip left-right\n",
|
||||
"jnp.fliplr(a)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:14:28.668195300Z",
|
||||
"start_time": "2024-06-06T11:14:28.631570500Z"
|
||||
}
|
||||
},
|
||||
"id": "1fffa4e597ab5732"
|
||||
"id": "6852e4e58b81dd9"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -194,7 +115,7 @@
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "ca53c916dcff12ae"
|
||||
"id": "97a50322218a0427"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user