Files
tensorneat-mend/test/test_kan.ipynb
2024-07-10 16:58:58 +08:00

165 lines
11 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-06-02T11:14:55.056050100Z",
"start_time": "2024-06-02T11:14:55.008909900Z"
}
},
"outputs": [
{
"data": {
"text/plain": "State ({'kan_initial_grids': Array([-1. , -0.5, 0. , 0.5, 1. ], dtype=float32)})"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Smoke test for a KAN genome: KANNode nodes with BSplineConn connections, 3 inputs -> 1 output.\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from algorithm.neat import DefaultGenome, DefaultMutation\n",
"from algorithm.neat.gene.node.kan_node import KANNode\n",
"from algorithm.neat.gene.conn.bspline import BSplineConn\n",
"# NOTE(review): the imports above resolve from the package root, but Act is imported with a\n",
"# `tensorneat.` prefix; confirm it resolves to the same module the algorithm package uses.\n",
"from tensorneat.utils import Act\n",
"\n",
"genome = DefaultGenome(\n",
"    num_inputs=3,\n",
"    num_outputs=1,\n",
"    max_nodes=5,  # capacity of the NaN-padded node array\n",
"    max_conns=10,  # capacity of the NaN-padded connection array\n",
"    node_gene=KANNode(),\n",
"    conn_gene=BSplineConn(),\n",
"    output_transform=Act.sigmoid,  # activation applied to the output node\n",
"    mutation=DefaultMutation(\n",
"        node_add=0.1,\n",
"        conn_add=0.1,\n",
"        node_delete=0.05,\n",
"        conn_delete=0.05,\n",
"    ),\n",
")\n",
"state = genome.setup()\n",
"state"
]
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "(Array([[0.],\n [1.],\n [2.],\n [3.],\n [4.]], dtype=float32, weak_type=True),\n Array([[ 0. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 0.04929435, -1.2567043 , 1.1369427 ,\n 0.6141437 , 1.4434636 , 0.24439397, 0.77281904],\n [ 1. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 0.90565056, 1.4197341 , 0.82603943,\n 1.164936 , -0.74349356, 0.9511131 , -1.5443964 ],\n [ 2. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 1.7152852 , -1.6385511 , 1.0964565 ,\n 0.6741095 , 1.4752939 , -0.3695403 , -0.5071054 ],\n [ 4. , 3. , -1. , -0.5 , 0. ,\n 0.5 , 1. , -1.2653785 , -1.2907758 , 0.6196416 ,\n -0.8124694 , -0.7498491 , -1.582707 , -0.04516089],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan]], dtype=float32, weak_type=True))"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sample one random genome (NaN-padded node and connection arrays) from a fixed seed.\n",
"SEED = 0\n",
"init_key = jax.random.key(SEED)\n",
"nodes, conns = genome.initialize(state, init_key)\n",
"nodes, conns"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-02T11:15:00.563905700Z",
"start_time": "2024-06-02T11:14:58.394859200Z"
}
},
"id": "825037f59b1e2ab5"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),\n Array([[0.],\n [1.],\n [2.],\n [3.],\n [4.]], dtype=float32, weak_type=True),\n Array([[[ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.5 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0.5 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 1. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.04929435],\n [ nan, nan, nan, nan,\n 0.90565056],\n [ nan, nan, nan, nan,\n 1.7152852 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.2653785 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n -1.2567043 ],\n [ nan, nan, nan, nan,\n 1.4197341 ],\n [ nan, nan, nan, nan,\n -1.6385511 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.2907758 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1.1369427 ],\n [ nan, nan, nan, nan,\n 0.82603943],\n [ nan, nan, nan, nan,\n 1.0964565 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0.6196416 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.6141437 ],\n [ nan, nan, nan, nan,\n 1.164936 ],\n [ nan, nan, nan, nan,\n 0.6741095 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.8124694 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1.4434636 ],\n [ nan, nan, nan, nan,\n -0.74349356],\n [ nan, nan, nan, nan,\n 1.4752939 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.7498491 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 
0.24439397],\n [ nan, nan, nan, nan,\n 0.9511131 ],\n [ nan, nan, nan, nan,\n -0.3695403 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.582707 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.77281904],\n [ nan, nan, nan, nan,\n -1.5443964 ],\n [ nan, nan, nan, nan,\n -0.5071054 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.04516089,\n nan]]], dtype=float32, weak_type=True))"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert the genome's gene arrays into the form consumed by genome.forward in the next cell.\n",
"transformed = genome.transform(state, nodes, conns)\n",
"transformed"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-02T11:15:59.432329500Z",
"start_time": "2024-06-02T11:15:58.667824700Z"
}
},
"id": "946ffb375548130f"
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "Array([nan], dtype=float32, weak_type=True)"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Forward pass on one input vector; with a sigmoid output transform the result should be finite.\n",
"test_input = jnp.array([1.0, 1.0, 1.0])\n",
"res = genome.forward(state, test_input, transformed)\n",
"# A NaN/inf here means the forward pass is broken; fail loudly instead of just displaying it.\n",
"assert jnp.isfinite(res).all(), f\"genome.forward returned non-finite output: {res}\"\n",
"res"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-02T11:17:08.398283400Z",
"start_time": "2024-06-02T11:17:08.009319200Z"
}
},
"id": "9c5b0e1428868f61"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}