165 lines
11 KiB
Plaintext
165 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "initial_id",
|
|
"metadata": {
|
|
"collapsed": true,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-02T11:14:55.056050100Z",
|
|
"start_time": "2024-06-02T11:14:55.008909900Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "State ({'kan_initial_grids': Array([-1. , -0.5, 0. , 0.5, 1. ], dtype=float32)})"
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"from pipeline import Pipeline\n",
"# Import the two names this notebook uses explicitly instead of via a wildcard,\n",
"# so it is clear where DefaultGenome and DefaultMutation come from.\n",
"from algorithm.neat import DefaultGenome, DefaultMutation\n",
"from algorithm.neat.gene.node.kan_node import KANNode\n",
"from algorithm.neat.gene.conn.bspline import BSplineConn\n",
"from problem.func_fit import XOR3d\n",
"from tensorneat.utils import ACT\n",
"\n",
"import jax, jax.numpy as jnp\n",
"\n",
"# Genome with 3 inputs and 1 output, built from KAN nodes and B-spline connections.\n",
"genome = DefaultGenome(\n",
"    num_inputs=3,\n",
"    num_outputs=1,\n",
"    max_nodes=5,\n",
"    max_conns=10,\n",
"    node_gene=KANNode(),\n",
"    conn_gene=BSplineConn(),\n",
"    output_transform=ACT.sigmoid,  # the activation function for output node\n",
"    mutation=DefaultMutation(\n",
"        node_add=0.1,\n",
"        conn_add=0.1,\n",
"        node_delete=0.05,\n",
"        conn_delete=0.05,\n",
"    ),\n",
")\n",
"\n",
"# setup() returns the state object that the genome calls in the following cells take as first argument.\n",
"state = genome.setup()\n",
"state"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "(Array([[0.],\n [1.],\n [2.],\n [3.],\n [4.]], dtype=float32, weak_type=True),\n Array([[ 0. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 0.04929435, -1.2567043 , 1.1369427 ,\n 0.6141437 , 1.4434636 , 0.24439397, 0.77281904],\n [ 1. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 0.90565056, 1.4197341 , 0.82603943,\n 1.164936 , -0.74349356, 0.9511131 , -1.5443964 ],\n [ 2. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 1.7152852 , -1.6385511 , 1.0964565 ,\n 0.6741095 , 1.4752939 , -0.3695403 , -0.5071054 ],\n [ 4. , 3. , -1. , -0.5 , 0. ,\n 0.5 , 1. , -1.2653785 , -1.2907758 , 0.6196416 ,\n -0.8124694 , -0.7498491 , -1.582707 , -0.04516089],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan]], dtype=float32, weak_type=True))"
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"# Fixed PRNG seed, so genome.initialize() gets the same key on every run.\n",
"SEED = 0\n",
"randkey = jax.random.key(SEED)\n",
"\n",
"# initialize() returns two arrays, unpacked here as the node and connection tables.\n",
"nodes, conns = genome.initialize(state, randkey)\n",
"nodes, conns"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-02T11:15:00.563905700Z",
|
|
"start_time": "2024-06-02T11:14:58.394859200Z"
|
|
}
|
|
},
|
|
"id": "825037f59b1e2ab5"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),\n Array([[0.],\n [1.],\n [2.],\n [3.],\n [4.]], dtype=float32, weak_type=True),\n Array([[[ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.5 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0.5 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 1. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.04929435],\n [ nan, nan, nan, nan,\n 0.90565056],\n [ nan, nan, nan, nan,\n 1.7152852 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.2653785 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n -1.2567043 ],\n [ nan, nan, nan, nan,\n 1.4197341 ],\n [ nan, nan, nan, nan,\n -1.6385511 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.2907758 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1.1369427 ],\n [ nan, nan, nan, nan,\n 0.82603943],\n [ nan, nan, nan, nan,\n 1.0964565 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0.6196416 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.6141437 ],\n [ nan, nan, nan, nan,\n 1.164936 ],\n [ nan, nan, nan, nan,\n 0.6741095 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.8124694 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1.4434636 ],\n [ nan, nan, nan, nan,\n -0.74349356],\n [ nan, nan, nan, nan,\n 1.4752939 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.7498491 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 
0.24439397],\n [ nan, nan, nan, nan,\n 0.9511131 ],\n [ nan, nan, nan, nan,\n -0.3695403 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.582707 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.77281904],\n [ nan, nan, nan, nan,\n -1.5443964 ],\n [ nan, nan, nan, nan,\n -0.5071054 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.04516089,\n nan]]], dtype=float32, weak_type=True))"
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"# Convert the (nodes, conns) pair into the value that genome.forward() receives in the next cell.\n",
"transformed = genome.transform(\n",
"    state,\n",
"    nodes,\n",
"    conns,\n",
")\n",
"transformed"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-02T11:15:59.432329500Z",
|
|
"start_time": "2024-06-02T11:15:58.667824700Z"
|
|
}
|
|
},
|
|
"id": "946ffb375548130f"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([nan], dtype=float32, weak_type=True)"
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"# Evaluate the genome on a single sample whose three inputs are all 1.\n",
"sample_inputs = jnp.array([1, 1, 1])\n",
"res = genome.forward(state, sample_inputs, transformed)\n",
"\n",
"# NOTE(review): the output saved with this notebook is [nan]; confirm whether that is expected.\n",
"res"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-02T11:17:08.398283400Z",
|
|
"start_time": "2024-06-02T11:17:08.009319200Z"
|
|
}
|
|
},
|
|
"id": "9c5b0e1428868f61"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "6f8739dee0d50371"
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|