{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "kan-demo-title",
   "metadata": {},
   "source": [
    "# KAN genome smoke test\n",
    "\n",
    "Builds a `DefaultGenome` that uses `KANNode` node genes and `BSplineConn` connection genes, then initializes it, transforms it, and runs a single forward pass.\n",
    "\n",
    "**TODO:** the forward pass at the end currently returns `nan`; the cause has not been identified in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-06-02T11:14:55.056050100Z",
     "start_time": "2024-06-02T11:14:55.008909900Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "State ({'kan_initial_grids': Array([-1. , -0.5, 0. , 0.5, 1. ], dtype=float32)})"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pipeline import Pipeline\n",
    "# Explicit names instead of `import *`, so it is clear where each symbol comes from.\n",
    "from algorithm.neat import DefaultGenome, DefaultMutation\n",
    "from algorithm.neat.gene.node.kan_node import KANNode\n",
    "from algorithm.neat.gene.conn.bspline import BSplineConn\n",
    "from problem.func_fit import XOR3d\n",
    "from tensorneat.utils import ACT\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "genome = DefaultGenome(\n",
    "    num_inputs=3,\n",
    "    num_outputs=1,\n",
    "    max_nodes=5,\n",
    "    max_conns=10,\n",
    "    node_gene=KANNode(),\n",
    "    conn_gene=BSplineConn(),\n",
    "    output_transform=ACT.sigmoid,  # the activation function for output node\n",
    "    mutation=DefaultMutation(\n",
    "        node_add=0.1,\n",
    "        conn_add=0.1,\n",
    "        node_delete=0.05,\n",
    "        conn_delete=0.05,\n",
    "    ),\n",
    ")\n",
    "state = genome.setup()\n",
    "state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "kan-demo-initialize",
   "metadata": {},
   "source": [
    "## Initialize the genome\n",
    "\n",
    "`genome.initialize` returns the `nodes` and `conns` arrays. In the stored output below, `conns` has 10 rows, of which 4 hold values and the remaining 6 are `nan`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "data": {
      "text/plain": "(Array([[0.],\n [1.],\n [2.],\n [3.],\n [4.]], dtype=float32, weak_type=True),\n Array([[ 0. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 0.04929435, -1.2567043 , 1.1369427 ,\n 0.6141437 , 1.4434636 , 0.24439397, 0.77281904],\n [ 1. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 0.90565056, 1.4197341 , 0.82603943,\n 1.164936 , -0.74349356, 0.9511131 , -1.5443964 ],\n [ 2. , 4. , -1. , -0.5 , 0. ,\n 0.5 , 1. , 1.7152852 , -1.6385511 , 1.0964565 ,\n 0.6741095 , 1.4752939 , -0.3695403 , -0.5071054 ],\n [ 4. , 3. , -1. , -0.5 , 0. ,\n 0.5 , 1. , -1.2653785 , -1.2907758 , 0.6196416 ,\n -0.8124694 , -0.7498491 , -1.582707 , -0.04516089],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan],\n [ nan, nan, nan, nan, nan,\n nan, nan, nan, nan, nan,\n nan, nan, nan, nan]], dtype=float32, weak_type=True))"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "randkey = jax.random.key(0)\n",
    "nodes, conns = genome.initialize(state, randkey)\n",
    "nodes, conns"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-02T11:15:00.563905700Z",
     "start_time": "2024-06-02T11:14:58.394859200Z"
    }
   },
   "id": "825037f59b1e2ab5"
  },
  {
   "cell_type": "markdown",
   "id": "kan-demo-transform",
   "metadata": {},
   "source": [
    "## Transform\n",
    "\n",
    "`genome.transform` turns `nodes` and `conns` into the structure that `genome.forward` consumes in the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),\n Array([[0.],\n [1.],\n [2.],\n [3.],\n [4.]], dtype=float32, weak_type=True),\n Array([[[ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n -1. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n -0.5 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.5 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n 0. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n 0.5 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0.5 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n 1. ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 1. ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.04929435],\n [ nan, nan, nan, nan,\n 0.90565056],\n [ nan, nan, nan, nan,\n 1.7152852 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.2653785 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n -1.2567043 ],\n [ nan, nan, nan, nan,\n 1.4197341 ],\n [ nan, nan, nan, nan,\n -1.6385511 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.2907758 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1.1369427 ],\n [ nan, nan, nan, nan,\n 0.82603943],\n [ nan, nan, nan, nan,\n 1.0964565 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, 0.6196416 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.6141437 ],\n [ nan, nan, nan, nan,\n 1.164936 ],\n [ nan, nan, nan, nan,\n 0.6741095 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.8124694 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 1.4434636 ],\n [ nan, nan, nan, nan,\n -0.74349356],\n [ nan, nan, nan, nan,\n 1.4752939 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.7498491 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.24439397],\n [ nan, nan, nan, nan,\n 0.9511131 ],\n [ nan, nan, nan, nan,\n -0.3695403 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -1.582707 ,\n nan]],\n \n [[ nan, nan, nan, nan,\n 0.77281904],\n [ nan, nan, nan, nan,\n -1.5443964 ],\n [ nan, nan, nan, nan,\n -0.5071054 ],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.04516089,\n nan]]], dtype=float32, weak_type=True))"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformed = genome.transform(state, nodes, conns)\n",
    "transformed"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-02T11:15:59.432329500Z",
     "start_time": "2024-06-02T11:15:58.667824700Z"
    }
   },
   "id": "946ffb375548130f"
  },
  {
   "cell_type": "markdown",
   "id": "kan-demo-forward",
   "metadata": {},
   "source": [
    "## Forward pass\n",
    "\n",
    "Evaluate the transformed genome on the input `[1, 1, 1]`.\n",
    "\n",
    "**Note:** the stored result is `nan`. TODO: investigate why before relying on this configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "data": {
      "text/plain": "Array([nan], dtype=float32, weak_type=True)"
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = genome.forward(state, jnp.array([1, 1, 1]), transformed)\n",
    "res"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-02T11:17:08.398283400Z",
     "start_time": "2024-06-02T11:17:08.009319200Z"
    }
   },
   "id": "9c5b0e1428868f61"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}