add sympy support; which can transfer your network into sympy expression;

add visualize in genome;
add related tests.
This commit is contained in:
wls2002
2024-06-12 21:36:35 +08:00
parent dfc8f9198e
commit b3e442c688
29 changed files with 6196 additions and 168 deletions

View File

@@ -2,189 +2,110 @@
"cells": [
{
"cell_type": "code",
"execution_count": 22,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-06-06T11:55:39.434327400Z",
"start_time": "2024-06-06T11:55:39.361327400Z"
}
},
"execution_count": 1,
"outputs": [
{
"data": {
"text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)"
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
},
"execution_count": 22,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"a = jnp.array([\n",
" [1, 2],\n",
" [3, 4]\n",
"])\n",
"def rot_boards(board):\n",
" def rot(a, _):\n",
" a = jnp.rot90(a)\n",
" return a, a # carry, y\n",
" \n",
" _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))\n",
" return boards\n",
"a1 = rot_boards(a)\n",
"a2 = rot_boards(a)\n",
"\n",
"a = jnp.concatenate([a1, a2], axis=0)\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)"
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = a.reshape(8, -1)\n",
"a"
"from algorithm.neat import *\n",
"from utils import Act, Agg\n",
"genome = DefaultGenome(\n",
" num_inputs=27,\n",
" num_outputs=8,\n",
" max_nodes=100,\n",
" max_conns=200,\n",
" node_gene=DefaultNodeGene(\n",
" activation_options=(Act.tanh,),\n",
" activation_default=Act.tanh,\n",
" ),\n",
" output_transform=Act.tanh,\n",
")\n",
"state = genome.setup()\n",
"genome"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:55:31.121054800Z",
"start_time": "2024-06-06T11:55:31.075517200Z"
"end_time": "2024-06-09T12:08:22.569123400Z",
"start_time": "2024-06-09T12:08:19.331863800Z"
}
},
"id": "639cdecea840351d"
"id": "b2b214a5454c4814"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"state = state.register(data=jnp.zeros((1, 27)))\n",
"# try to save the genome object\n",
"import pickle\n",
"\n",
"with open('genome.pkl', 'wb') as f:\n",
" genome.__dict__[\"state\"] = state\n",
" pickle.dump(genome, f)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:09:01.943445900Z",
"start_time": "2024-06-09T12:09:01.919416Z"
}
},
"id": "28348dfc458e8473"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"action = [\"up\", \"right\", \"down\", \"left\"]\n",
"lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n",
"def action_rot90(li):\n",
" first = li[0]\n",
" return li[1:] + [first]\n",
"\n",
"a = a\n",
"rl_flip_a = jnp.fliplr(a)"
"# try to load the genome object\n",
"with open('genome.pkl', 'rb') as f:\n",
" genome = pickle.load(f)\n",
" state = genome.state\n",
" del genome.__dict__[\"state\"]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.417287600Z",
"start_time": "2024-06-06T11:22:36.414285500Z"
"end_time": "2024-06-09T12:10:28.621539400Z",
"start_time": "2024-06-09T12:10:28.612540100Z"
}
},
"id": "92b75cd0e870a28c"
"id": "c91be9fe3d2b5d5d"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1 2]\n",
" [3 4]] ['up', 'right', 'down', 'left']\n",
"[[2 1]\n",
" [4 3]] ['up', 'left', 'down', 'right']\n",
"[[2 4]\n",
" [1 3]] ['right', 'down', 'left', 'up']\n",
"[[1 3]\n",
" [2 4]] ['left', 'down', 'right', 'up']\n",
"[[4 3]\n",
" [2 1]] ['down', 'left', 'up', 'right']\n",
"[[3 4]\n",
" [1 2]] ['down', 'right', 'up', 'left']\n",
"[[3 1]\n",
" [4 2]] ['left', 'up', 'right', 'down']\n",
"[[4 2]\n",
" [3 1]] ['right', 'up', 'left', 'down']\n"
]
}
],
"source": [
"for i in range(4):\n",
" print(a, action)\n",
" print(rl_flip_a, lr_flip_action)\n",
" a = jnp.rot90(a)\n",
" rl_flip_a = jnp.rot90(rl_flip_a)\n",
" action = action_rot90(action)\n",
" lr_flip_action = action_rot90(lr_flip_action)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.919614600Z",
"start_time": "2024-06-06T11:22:36.860704600Z"
}
},
"id": "55e802e0dbcc9c7f"
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 15,
"outputs": [
{
"data": {
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)"
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
},
"execution_count": 6,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.rot90(a, k=2)"
"state"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:12:48.186719Z",
"start_time": "2024-06-06T11:12:48.151161900Z"
"end_time": "2024-06-09T12:10:34.103124Z",
"start_time": "2024-06-09T12:10:34.096124300Z"
}
},
"id": "16f8de3cadaa257a"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# flip left-right\n",
"jnp.fliplr(a)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:14:28.668195300Z",
"start_time": "2024-06-06T11:14:28.631570500Z"
}
},
"id": "1fffa4e597ab5732"
"id": "6852e4e58b81dd9"
},
{
"cell_type": "code",
@@ -194,7 +115,7 @@
"metadata": {
"collapsed": false
},
"id": "ca53c916dcff12ae"
"id": "97a50322218a0427"
}
],
"metadata": {