{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Exporting a NEAT genome to SymPy\n",
    "\n",
    "Initialise a small NEAT genome (3 inputs, 1 output, 1 hidden node), export the network as SymPy expressions and LaTeX, and check that the symbolic forward function agrees with `genome.forward`."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "5b0e7d2c41a9f386"
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import jax\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "\n",
    "from algorithm.neat.genome.advance import AdvanceInitialize\n",
    "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
    "from utils import Agg"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T11:35:46.886073700Z",
     "start_time": "2024-06-12T11:35:46.042288800Z"
    }
   },
   "id": "9531a569d9ecf774"
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "SEED = 42  # seeds both the JAX genome initialisation and the NumPy test input\n",
    "\n",
    "genome = AdvanceInitialize(\n",
    "    num_inputs=3,\n",
    "    num_outputs=1,\n",
    "    hidden_cnt=1,\n",
    "    max_nodes=50,\n",
    "    max_conns=500,\n",
    "    # Restrict aggregation to `sum`; activation is left at the gene's default.\n",
    "    node_gene=NodeGeneWithoutResponse(\n",
    "        aggregation_default=Agg.sum,\n",
    "        aggregation_options=(Agg.sum,),\n",
    "    ),\n",
    ")\n",
    "\n",
    "state = genome.setup()\n",
    "\n",
    "randkey = jax.random.PRNGKey(SEED)\n",
    "nodes, conns = genome.initialize(state, randkey)\n",
    "\n",
    "network = genome.network_dict(state, nodes, conns)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T11:35:52.274062400Z",
     "start_time": "2024-06-12T11:35:46.892042200Z"
    }
   },
   "id": "4013c9f9d5472eb7"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "data": {
      "text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3)\n",
    "output_exprs"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T11:35:52.325161800Z",
     "start_time": "2024-06-12T11:35:52.282008300Z"
    }
   },
   "id": "addea793fc002900"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
     ]
    }
   ],
   "source": [
    "print(sp.latex(output_exprs[0]))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T11:35:52.341639700Z",
     "start_time": "2024-06-12T11:35:52.323163700Z"
    }
   },
   "id": "967cb87e24373f77"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Numerical check\n",
    "\n",
    "Evaluate the SymPy forward function and the JAX forward pass on the same seeded random input. The exported expression carries coefficients rounded to 3 decimals (`precision=3`), so the two results are expected to agree only approximately."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "88eee4db9eb857cd"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(SEED)\n",
    "random_inputs = rng.standard_normal(3)\n",
    "res_sympy = forward_func(random_inputs)\n",
    "res_sympy"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c5581201d990ba1c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "transformed = genome.transform(state, nodes, conns)\n",
    "res_jax = genome.forward(state, transformed, random_inputs)\n",
    "res_jax"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "fe3449a5bc688bc3"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Loose tolerance: the symbolic model uses rounded coefficients and the JAX pass runs in float32.\n",
    "np.testing.assert_allclose(\n",
    "    np.asarray(res_sympy, dtype=float),\n",
    "    np.asarray(res_jax, dtype=float),\n",
    "    atol=1e-2,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "174c7dc3d9499f95"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}