{ "cells": [ { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "import jax, jax.numpy as jnp\n", "\n", "# NOTE(review): every name the cells below use is imported explicitly or defined in a cell, so this star import\n", "# and topological_sort_python look unneeded; confirm nothing relies on import side effects before removing.\n", "from algorithm.neat import *\n", "from algorithm.neat.genome.advance import AdvanceInitialize\n", "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "from utils.graph import topological_sort_python\n", "from utils import Act, Agg\n", "\n", "import numpy as np" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-12T21:48:58.065855900Z", "start_time": "2024-06-12T21:48:57.292767Z" } }, "id": "9531a569d9ecf774" }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "# Genome with 3 inputs, 3 outputs and 2 hidden nodes; aggregation is pinned to sum and outputs go through tanh.\n", "genome = AdvanceInitialize(\n", " num_inputs=3,\n", " num_outputs=3,\n", " hidden_cnt=2,\n", " max_nodes=50,\n", " max_conns=500,\n", " node_gene=NodeGeneWithoutResponse(\n", " # Activation stays at the node gene's default ('sigmoid' in the network dump below).\n", " # activation_default=Act.tanh,\n", " aggregation_default=Agg.sum,\n", " # activation_options=(Act.tanh,),\n", " aggregation_options=(Agg.sum,),\n", " ),\n", " output_transform=jnp.tanh,\n", ")\n", "\n", "state = genome.setup()\n", "\n", "# Fixed PRNG key so the initialized genome is reproducible.\n", "randkey = jax.random.PRNGKey(42)\n", "nodes, conns = genome.initialize(state, randkey)\n", "\n", "# Plain-dict view of the genome's nodes and connections, displayed by the next cell.\n", "network = genome.network_dict(state, nodes, conns)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-12T21:49:03.858545Z", "start_time": "2024-06-12T21:48:58.071859800Z" } }, "id": "4013c9f9d5472eb7" }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "data": { "text/plain": "{'nodes': {0: {'idx': 0,\n 'bias': array(0.22059791, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 1: {'idx': 1,\n 'bias': array(0.7715081, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 2: {'idx': 2,\n 'bias': array(1.1184921, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 3: {'idx': 3,\n 'bias': array(0.6967973, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 4: {'idx': 4,\n 'bias': array(0.85948837, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 5: 
{'idx': 5,\n 'bias': array(0.19332138, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 6: {'idx': 6,\n 'bias': array(-0.31763914, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 7: {'idx': 7,\n 'bias': array(0.05656302, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'}},\n 'conns': {(0, 6): {'in': 0,\n 'out': 6,\n 'weight': array(1.6676894, dtype=float32)},\n (0, 7): {'in': 0, 'out': 7, 'weight': array(-0.05250553, dtype=float32)},\n (1, 6): {'in': 1, 'out': 6, 'weight': array(0.10137014, dtype=float32)},\n (1, 7): {'in': 1, 'out': 7, 'weight': array(-0.12093307, dtype=float32)},\n (2, 6): {'in': 2, 'out': 6, 'weight': array(-1.8677292, dtype=float32)},\n (2, 7): {'in': 2, 'out': 7, 'weight': array(-0.4195783, dtype=float32)},\n (6, 3): {'in': 6, 'out': 3, 'weight': array(1.2615877, dtype=float32)},\n (6, 4): {'in': 6, 'out': 4, 'weight': array(-0.27593768, dtype=float32)},\n (6, 5): {'in': 6, 'out': 5, 'weight': array(-0.5819819, dtype=float32)},\n (7, 3): {'in': 7, 'out': 3, 'weight': array(0.59301573, dtype=float32)},\n (7, 4): {'in': 7, 'out': 4, 'weight': array(0.19493186, dtype=float32)},\n (7, 5): {'in': 7, 'out': 5, 'weight': array(0.18183969, dtype=float32)}}}" }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "network" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-12T21:49:03.873543600Z", "start_time": "2024-06-12T21:49:03.867543Z" } }, "id": "188006cebb04847" }, { "cell_type": "code", "execution_count": 11, "outputs": [], "source": [ "import sympy as sp\n", "\n", "# NOTE(review): the commented-out call below is an unused earlier variant; consider deleting it.\n", "# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)\n", "# Export the same symbolic network twice: with sympy_func's default backend (presumably JAX, per the variable\n", "# name) and with backend='numpy'. sp.tanh mirrors the genome's output_transform=jnp.tanh.\n", "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n", "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, np_forward_func = genome.sympy_func(state, network, 
sympy_output_transform=sp.tanh, backend='numpy')\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-12T21:50:37.527882500Z", "start_time": "2024-06-12T21:50:37.518559400Z" } }, "id": "addea793fc002900" },
{ "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "# Seeded generator so the backend comparison below sees the same inputs on every Restart & Run All.\n", "input_rng = np.random.default_rng(42)\n", "random_inputs = input_rng.standard_normal(3).astype(np.float32)\n", "random_inputs, random_inputs.dtype" ], "metadata": { "collapsed": false }, "id": "3aa7c874f3a5743f" },
{ "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "transformed = genome.transform(state, nodes, conns)\n", "res = genome.forward(state, transformed, random_inputs)\n", "res" ], "metadata": { "collapsed": false }, "id": "fe3449a5bc688bc3" },
{ "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "res1 = np.array(jax_forward_func(random_inputs), dtype=np.float32)\n", "res2 = np.array(np_forward_func(random_inputs), dtype=np.float32)\n", "res = np.array(genome.forward(state, transformed, random_inputs))\n", "res1, res2, res" ], "metadata": { "collapsed": false }, "id": "a874d434509f1092" },
{ "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "# The generated backends can round float32 arithmetic differently (an earlier run gave 0.57646036 vs 0.5764604),\n", "# so compare with a tolerance instead of exact ==.\n", "tol = 1e-6  # ~10 float32 ulps near 1; outputs are tanh-squashed into (-1, 1)\n", "np.testing.assert_allclose(res1, res, rtol=tol, atol=tol)\n", "np.testing.assert_allclose(res2, res, rtol=tol, atol=tol)" ], "metadata": { "collapsed": false }, "id": "d226e5bd6e2d44d6" },
{ "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "# Largest absolute deviation of each generated forward function from genome.forward.\n", "np.abs(res1 - res).max(), np.abs(res2 - res).max()" ], "metadata": { "collapsed": false }, "id": "2a36ce6afc59ee8a" } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 }