{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "id": "md-overview",
   "source": [
    "# Sympy export vs. `genome.forward`: numerical consistency\n",
    "\n",
    "Builds a NEAT genome (20 inputs, 1 output, identity activation, sum aggregation), exports it with `genome.sympy_func`, and checks that the sympy-derived JAX function agrees with `genome.forward` on random inputs.\n",
    "\n",
    "The later sections are scratch experiments on float32 rounding, which is why the two forward results are compared with a tolerance rather than with `==`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import jax, jax.numpy as jnp\n",
    "\n",
    "from algorithm.neat import *\n",
    "from algorithm.neat.genome.advance import AdvanceInitialize\n",
    "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
    "from utils.graph import topological_sort_python\n",
    "from tensorneat.utils import ACT, AGG\n",
    "\n",
    "import numpy as np\n",
    "import sympy as sp"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:58:32.064076300Z",
     "start_time": "2024-06-12T22:58:31.208435400Z"
    }
   },
   "id": "9531a569d9ecf774"
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "outputs": [],
   "source": [
    "genome = AdvanceInitialize(\n",
    "    num_inputs=20,\n",
    "    num_outputs=1,\n",
    "    hidden_cnt=2,\n",
    "    max_nodes=30,\n",
    "    max_conns=50,\n",
    "    node_gene=NodeGeneWithoutResponse(\n",
    "        activation_default=ACT.identity,\n",
    "        aggregation_default=AGG.sum,\n",
    "        # activation_options=(ACT.tanh, ACT.sigmoid, ACT.identity, ACT.clamped),\n",
    "        activation_options=(ACT.identity,),\n",
    "        aggregation_options=(AGG.sum,),\n",
    "    ),\n",
    "    # output_transform=jnp.tanh,\n",
    ")\n",
    "\n",
    "state = genome.setup()\n",
    "\n",
    "randkey = jax.random.PRNGKey(42)\n",
    "nodes, conns = genome.initialize(state, randkey)\n",
    "\n",
    "network = genome.network_dict(state, nodes, conns)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T23:06:06.346267500Z",
     "start_time": "2024-06-12T23:06:06.121268500Z"
    }
   },
   "id": "4013c9f9d5472eb7"
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "outputs": [],
   "source": [
    "# Variant with an output transform (presumably paired with output_transform=jnp.tanh above):\n",
    "# ... = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n",
    "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T23:06:06.698824100Z",
     "start_time": "2024-06-12T23:06:06.688855300Z"
    }
   },
   "id": "addea793fc002900"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "id": "md-compare-random-inputs",
   "source": [
    "## Compare both forward paths on random inputs\n",
    "\n",
    "`random_inputs` is drawn from NumPy without a fixed seed, so the recorded outputs below will differ on a re-run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "outputs": [
    {
     "data": {
      "text/plain": "(Array([-0.10080967, -2.373122 , -0.12224621, 1.0417817 , 0.26311624,\n -0.04573117, 0.5329444 , 1.9844177 , -0.5471916 , -3.0961084 ,\n 0.07978257, -1.0657575 , -1.6740963 , 1.2435746 , -0.5811825 ,\n 0.8970058 , -0.4379712 , 0.9084878 , -1.0984142 , 0.33063456], dtype=float32),\n dtype('float32'))"
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random_inputs = np.random.randn(20).astype(np.float32)\n",
    "random_inputs = jax.device_put(random_inputs)\n",
    "random_inputs, random_inputs.dtype"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T23:06:13.317466Z",
     "start_time": "2024-06-12T23:06:13.251298500Z"
    }
   },
   "id": "3aa7c874f3a5743f"
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "outputs": [
    {
     "data": {
      "text/plain": "Array([-3.769288], dtype=float32, weak_type=True)"
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformed = genome.transform(state, nodes, conns)\n",
    "res = genome.forward(state, transformed, random_inputs)\n",
    "res"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T23:06:13.734130200Z",
     "start_time": "2024-06-12T23:06:13.530130300Z"
    }
   },
   "id": "fe3449a5bc688bc3"
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "outputs": [
    {
     "data": {
      "text/plain": "(Array([-3.7692878], dtype=float32),\n Array([-3.769288], dtype=float32, weak_type=True))"
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res1 = jnp.array(jax_forward_func(random_inputs))\n",
    "res2 = genome.forward(state, transformed, random_inputs)\n",
    "res1, res2"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T23:06:14.069158700Z",
     "start_time": "2024-06-12T23:06:13.857130700Z"
    }
   },
   "id": "a874d434509f1092"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "id": "md-tolerance-check",
   "source": [
    "The two forward paths agree only up to float32 rounding. In the recorded run the values were `-3.7692878` and `-3.769288`, so an exact `==` reports a mismatch.\n",
    "\n",
    "The check below asserts agreement within a tolerance. It also shows the exact-equality result for reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "exact_match = bool(jnp.all(res1 == res2))\n",
    "close_match = bool(jnp.allclose(res1, res2, rtol=1e-5, atol=1e-5))\n",
    "assert close_match, (res1, res2)\n",
    "exact_match, close_match"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d226e5bd6e2d44d6"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "id": "md-float-scratch",
   "source": [
    "## Scratch: float32 rounding experiments\n",
    "\n",
    "These cells check whether algebraically equivalent expressions give bit-identical float32 results. They first compare `a + x*b` with `x*b + a` elementwise, then a left-to-right Python-loop sum with `jnp.sum`.\n",
    "\n",
    "Note: these cells reuse the names `random_inputs`, `res1` and `res2`, overwriting the values from the comparison above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "outputs": [],
   "source": [
    "random_inputs = np.random.randn(1000).astype(np.float32)\n",
    "random_inputs = jax.device_put(random_inputs)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:36:13.519243800Z",
     "start_time": "2024-06-12T22:36:13.502174700Z"
    }
   },
   "id": "90dd091f91df6fa2"
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "outputs": [],
   "source": [
    "res1 = 1.243123123 + random_inputs * 1.12413243123123"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:36:13.898386100Z",
     "start_time": "2024-06-12T22:36:13.886386800Z"
    }
   },
   "id": "b35d2c01f50071c2"
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "outputs": [],
   "source": [
    "res2 = random_inputs * 1.12413243123123 + 1.243123123"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:36:14.245129600Z",
     "start_time": "2024-06-12T22:36:14.220718100Z"
    }
   },
   "id": "e69752c00bd32361"
  },
  { "cell_type": "code", "execution_count": 76, "outputs": [ { "data": { "text/plain": "Array([ True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, 
True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, 
True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, 
True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True], dtype=bool)" }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res1 == res2" ], "metadata": { "collapsed": false, 
"ExecuteTime": { "end_time": "2024-06-12T22:36:14.573137100Z", "start_time": "2024-06-12T22:36:14.551029Z" } }, "id": "78b9e00669100231" },
  {
   "cell_type": "markdown",
   "metadata": {},
   "id": "md-summation-order",
   "source": [
    "In the recorded run, every element of `res1 == res2` was `True`: swapping the operand order of `a + x*b` did not change any float32 result.\n",
    "\n",
    "Summation order behaves differently. Below, a left-to-right Python loop over `res1` and `jnp.sum` over `res2` add elementwise-identical values. Yet they differ in the last digits (recorded: `1251.1074` vs `1251.1078`), presumably because `jnp.sum` accumulates in a different order. Reductions should therefore be compared with a tolerance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "outputs": [
    {
     "data": {
      "text/plain": "(Array(1251.1074, dtype=float32), Array(1251.1078, dtype=float32))"
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sequential_sum = 0.\n",
    "for value in res1:\n",
    "    sequential_sum += value\n",
    "reduced_sum = jnp.sum(res2)\n",
    "sequential_sum, reduced_sum"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:36:15.688800500Z",
     "start_time": "2024-06-12T22:36:15.088955300Z"
    }
   },
   "id": "91d739b6b1520363"
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "outputs": [
    {
     "data": {
      "text/plain": "Array(False, dtype=bool)"
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sequential_sum == reduced_sum"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:36:16.917702Z",
     "start_time": "2024-06-12T22:36:16.906916800Z"
    }
   },
   "id": "60502ba8f400bb60"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "jnp.allclose(sequential_sum, reduced_sum)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "allclose-sum-check"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "id": "md-masked-sum",
   "source": [
    "### Masked sum over a NaN-padded buffer\n",
    "\n",
    "These cells embed `n_real` random values in a NaN-filled buffer of length `n_padded` and sum it with a `where=~isnan` mask. They then check that the result equals the plain sum of the real values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [],
   "source": [
    "n_real = 10\n",
    "n_padded = 50000\n",
    "real_inputs = jax.device_put(np.random.randn(n_real).astype(np.float32))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:55:02.540229100Z",
     "start_time": "2024-06-12T22:55:02.530229200Z"
    }
   },
   "id": "5e1e7555937f1690"
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [],
   "source": [
    "padded = jnp.full((n_padded,), jnp.nan).at[:n_real].set(real_inputs)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:55:02.971772100Z",
     "start_time": "2024-06-12T22:55:02.961150Z"
    }
   },
   "id": "5ea2c2051880ca1e"
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "data": {
      "text/plain": "Array(-5.8886395, dtype=float32)"
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_sum = jnp.sum(padded, where=~jnp.isnan(padded))\n",
    "masked_sum"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:55:03.381951100Z",
     "start_time": "2024-06-12T22:55:03.370952400Z"
    }
   },
   "id": "9b4baa99710fad1"
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [
    {
     "data": {
      "text/plain": "Array(-5.8886395, dtype=float32)"
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "plain_sum = jnp.sum(real_inputs)\n",
    "plain_sum"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:55:03.781919300Z",
     "start_time": "2024-06-12T22:55:03.770892100Z"
    }
   },
   "id": "9b2897759c90b7c5"
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [
    {
     "data": {
      "text/plain": "Array(True, dtype=bool)"
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_sum == plain_sum"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-12T22:55:04.328871600Z",
     "start_time": "2024-06-12T22:55:04.314871Z"
    }
   },
   "id": "e312213166610144"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}