Files
tensorneat-mend/examples/interpret_visualize/genome_sympy.ipynb
2024-07-12 02:25:57 +08:00

489 lines
19 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Genome to SymPy: equivalence check\n",
"\n",
"Builds a small identity-activation NEAT genome, converts it with `genome.sympy_func`, and checks that the generated JAX forward function matches `genome.forward` on random inputs.\n",
"\n",
"The cells after that check are scratch experiments on float32 rounding: operand order in an affine map, sequential accumulation vs `jnp.sum`, and a masked sum over a NaN-padded buffer."
],
"metadata": {},
"id": "5b0e2c1a7d4f9e31"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
"from tensorneat.utils import ACT, AGG\n",
"\n",
"# NOTE(review): this cell mixes import roots (`algorithm.*` vs `tensorneat.utils`).\n",
"# If both resolve, ACT/AGG could come from a different module copy than the one\n",
"# the `algorithm.neat` classes use internally -- confirm and settle on one root.\n",
"\n",
"# Seed NumPy's global RNG so the np.random.randn inputs below are reproducible\n",
"# on a fresh Restart and Run All.\n",
"np.random.seed(42)"
],
"metadata": {
"collapsed": false
},
"id": "9531a569d9ecf774"
},
{
"cell_type": "code",
"execution_count": 32,
"outputs": [],
"source": [
"# Node gene restricted to identity activation and sum aggregation, so no\n",
"# activation nonlinearity is available to the evolved network.\n",
"node_gene = NodeGeneWithoutResponse(\n",
"    activation_default=ACT.identity,\n",
"    aggregation_default=AGG.sum,\n",
"    activation_options=(ACT.identity,),\n",
"    aggregation_options=(AGG.sum,),\n",
")\n",
"\n",
"# For a nonlinear variant, widen activation_options (e.g. ACT.tanh, ACT.sigmoid,\n",
"# ACT.clamped) and pass output_transform=jnp.tanh here; the sympy_func call in the\n",
"# next cell then needs sympy_output_transform=sp.tanh to stay equivalent.\n",
"genome = AdvanceInitialize(\n",
"    num_inputs=20,\n",
"    num_outputs=1,\n",
"    hidden_cnt=2,\n",
"    max_nodes=30,\n",
"    max_conns=50,\n",
"    node_gene=node_gene,\n",
")\n",
"\n",
"state = genome.setup()\n",
"randkey = jax.random.PRNGKey(42)\n",
"nodes, conns = genome.initialize(state, randkey)\n",
"network = genome.network_dict(state, nodes, conns)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T23:06:06.346267500Z",
"start_time": "2024-06-12T23:06:06.121268500Z"
}
},
"id": "4013c9f9d5472eb7"
},
{
"cell_type": "code",
"execution_count": 33,
"outputs": [],
"source": [
"import sympy as sp\n",
"\n",
"# sympy_func yields six values; jax_forward_func is the one exercised below,\n",
"# where it is called on the same inputs as genome.forward.\n",
"# If the genome was built with output_transform=jnp.tanh, also pass\n",
"# sympy_output_transform=sp.tanh here.\n",
"(\n",
"    symbols,\n",
"    args_symbols,\n",
"    input_symbols,\n",
"    nodes_exprs,\n",
"    output_exprs,\n",
"    jax_forward_func,\n",
") = genome.sympy_func(state, network)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T23:06:06.698824100Z",
"start_time": "2024-06-12T23:06:06.688855300Z"
}
},
"id": "addea793fc002900"
},
{
"cell_type": "code",
"execution_count": 38,
"outputs": [
{
"data": {
"text/plain": "(Array([-0.10080967, -2.373122 , -0.12224621, 1.0417817 , 0.26311624,\n -0.04573117, 0.5329444 , 1.9844177 , -0.5471916 , -3.0961084 ,\n 0.07978257, -1.0657575 , -1.6740963 , 1.2435746 , -0.5811825 ,\n 0.8970058 , -0.4379712 , 0.9084878 , -1.0984142 , 0.33063456], dtype=float32),\n dtype('float32'))"
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One standard-normal float32 value per genome input (num_inputs=20), moved onto the JAX device.\n",
"random_inputs = jax.device_put(np.random.randn(20).astype(np.float32))\n",
"random_inputs, random_inputs.dtype"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T23:06:13.317466Z",
"start_time": "2024-06-12T23:06:13.251298500Z"
}
},
"id": "3aa7c874f3a5743f"
},
{
"cell_type": "code",
"execution_count": 39,
"outputs": [
{
"data": {
"text/plain": "Array([-3.769288], dtype=float32, weak_type=True)"
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Reference result from the genome's own forward pass (no SymPy involved).\n",
"transformed = genome.transform(state, nodes, conns)\n",
"res = genome.forward(state, transformed, random_inputs)\n",
"res"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T23:06:13.734130200Z",
"start_time": "2024-06-12T23:06:13.530130300Z"
}
},
"id": "fe3449a5bc688bc3"
},
{
"cell_type": "code",
"execution_count": 40,
"outputs": [
{
"data": {
"text/plain": "(Array([-3.7692878], dtype=float32),\n Array([-3.769288], dtype=float32, weak_type=True))"
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same inputs through both paths: the sympy-generated function (res1)\n",
"# and the genome's forward pass (res2), shown side by side.\n",
"sympy_out = jax_forward_func(random_inputs)\n",
"res1 = jnp.array(sympy_out)\n",
"res2 = genome.forward(state, transformed, random_inputs)\n",
"(res1, res2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T23:06:14.069158700Z",
"start_time": "2024-06-12T23:06:13.857130700Z"
}
},
"id": "a874d434509f1092"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Exact `==` on float32 is too strict: the two results above differ only in the\n",
"# last printed digit (-3.7692878 vs -3.769288), a rounding-level gap that\n",
"# presumably comes from the two paths evaluating the expression in a different\n",
"# order. Compare within a float32-scale tolerance instead.\n",
"jnp.allclose(res1, res, rtol=1e-5, atol=1e-6)"
],
"metadata": {
"collapsed": false
},
"id": "d226e5bd6e2d44d6"
},
{
"cell_type": "code",
"execution_count": 73,
"outputs": [],
"source": [
"# Scratch experiment: does float32 arithmetic depend on operand order?\n",
"# NOTE: this rebinds random_inputs (previously the 20 genome inputs) to 1000 samples.\n",
"random_inputs = jax.device_put(np.random.randn(1000).astype(np.float32))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:36:13.519243800Z",
"start_time": "2024-06-12T22:36:13.502174700Z"
}
},
"id": "90dd091f91df6fa2"
},
{
"cell_type": "code",
"execution_count": 74,
"outputs": [],
"source": [
"# Affine map with the bias as the LEFT operand of the addition: add(b, x * w).\n",
"res1 = jnp.add(1.243123123, jnp.multiply(random_inputs, 1.12413243123123))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:36:13.898386100Z",
"start_time": "2024-06-12T22:36:13.886386800Z"
}
},
"id": "b35d2c01f50071c2"
},
{
"cell_type": "code",
"execution_count": 75,
"outputs": [],
"source": [
"# Same affine map with the bias as the RIGHT operand of the addition: add(x * w, b).\n",
"res2 = jnp.add(jnp.multiply(random_inputs, 1.12413243123123), 1.243123123)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:36:14.245129600Z",
"start_time": "2024-06-12T22:36:14.220718100Z"
}
},
"id": "e69752c00bd32361"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Element-wise `==` would print all 1000 booleans; reduce to a single flag.\n",
"# True means bias-first and bias-last gave exactly equal float32 results everywhere.\n",
"jnp.array_equal(res1, res2)"
],
"metadata": {
"collapsed": false
},
"id": "78b9e00669100231"
},
{
"cell_type": "code",
"execution_count": 77,
"outputs": [
{
"data": {
"text/plain": "(Array(1251.1074, dtype=float32), Array(1251.1078, dtype=float32))"
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Accumulate res1 strictly left to right, one element at a time, and set it\n",
"# against jnp.sum over res2. res1 and res2 were exactly equal element-wise, so any\n",
"# gap here comes from how the two sums accumulate.\n",
"resres1 = 0.\n",
"for i in range(1000):\n",
"    resres1 = resres1 + res1[i]\n",
"resres2 = jnp.sum(res2)\n",
"(resres1, resres2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:36:15.688800500Z",
"start_time": "2024-06-12T22:36:15.088955300Z"
}
},
"id": "91d739b6b1520363"
},
{
"cell_type": "code",
"execution_count": 78,
"outputs": [
{
"data": {
"text/plain": "Array(False, dtype=bool)"
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Are the sequential sum and jnp.sum exactly equal?\n",
"jnp.equal(resres1, resres2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:36:16.917702Z",
"start_time": "2024-06-12T22:36:16.906916800Z"
}
},
"id": "60502ba8f400bb60"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# In a top-to-bottom run res1/res2 hold 1000 values here; show a short preview\n",
"# instead of dumping both arrays.\n",
"res1[:10], res2[:10]"
],
"metadata": {
"collapsed": false
},
"id": "c1c092879ca33c3c"
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"# NaN-padding experiment: `real` genuine samples inside a buffer of length `full`.\n",
"real, full = 10, 50000\n",
"random_inputs = jax.device_put(np.random.randn(real).astype(np.float32))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:55:02.540229100Z",
"start_time": "2024-06-12T22:55:02.530229200Z"
}
},
"id": "5e1e7555937f1690"
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [],
"source": [
"# Buffer of `full` NaNs whose first `real` slots are overwritten with the samples.\n",
"all_nans = jnp.full(shape=(full,), fill_value=jnp.nan)\n",
"large = all_nans.at[:real].set(random_inputs)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:55:02.971772100Z",
"start_time": "2024-06-12T22:55:02.961150Z"
}
},
"id": "5ea2c2051880ca1e"
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [
{
"data": {
"text/plain": "Array(-5.8886395, dtype=float32)"
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sum only the non-NaN entries of the padded buffer.\n",
"is_valid = ~jnp.isnan(large)\n",
"res1 = jnp.sum(large, where=is_valid)\n",
"res1"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:55:03.381951100Z",
"start_time": "2024-06-12T22:55:03.370952400Z"
}
},
"id": "9b4baa99710fad1"
},
{
"cell_type": "code",
"execution_count": 23,
"outputs": [
{
"data": {
"text/plain": "Array(-5.8886395, dtype=float32)"
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Baseline: plain sum of the unpadded samples.\n",
"res2 = random_inputs.sum()\n",
"res2"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:55:03.781919300Z",
"start_time": "2024-06-12T22:55:03.770892100Z"
}
},
"id": "9b2897759c90b7c5"
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [
{
"data": {
"text/plain": "Array(True, dtype=bool)"
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Does the masked sum over the padded buffer match the plain sum exactly?\n",
"jnp.equal(res1, res2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T22:55:04.328871600Z",
"start_time": "2024-06-12T22:55:04.314871Z"
}
},
"id": "e312213166610144"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "7e9ca9fa15e3c401"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}