From aac9f4c3fbdfff94e3f93706fb68d2659e276fa9 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Thu, 13 Jun 2024 07:07:43 +0800 Subject: [PATCH] Try to reduce difference between sympy formula and network. Spend a whole night. Failed; I'll never try it anymore. --- .../algorithm/neat/gene/conn/default.py | 2 +- .../algorithm/neat/gene/node/default.py | 6 +- .../gene/node/default_without_response.py | 2 +- tensorneat/algorithm/neat/genome/default.py | 1 - .../interpret_visualize/genome_sympy.ipynb | 373 ++++++++++++++---- tensorneat/utils/aggregation/agg_jnp.py | 12 +- tensorneat/utils/aggregation/agg_sympy.py | 4 + 7 files changed, 313 insertions(+), 87 deletions(-) diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index ef94cb7..263b9ac 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -82,7 +82,7 @@ class DefaultConnGene(BaseConnGene): return { "in": int(conn[0]), "out": int(conn[1]), - "weight": np.array(conn[2], dtype=np.float32), + "weight": jnp.float32(conn[2]), } def sympy_func(self, state, conn_dict, inputs, precision=None): diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 4ca695f..ba5fee9 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -163,8 +163,8 @@ class DefaultNodeGene(BaseNodeGene): idx, bias, res, agg, act = node idx = int(idx) - bias = np.array(bias, dtype=np.float32) - res = np.array(res, dtype=np.float32) + bias = jnp.float32(bias) + res = jnp.float32(res) agg = int(agg) act = int(act) @@ -186,7 +186,7 @@ class DefaultNodeGene(BaseNodeGene): res = sp.symbols(f"n_{nd['idx']}_r") z = convert_to_sympy(nd["agg"])(inputs) - z = bias + z * res + z = bias + res * z if is_output_node: pass diff --git a/tensorneat/algorithm/neat/gene/node/default_without_response.py b/tensorneat/algorithm/neat/gene/node/default_without_response.py index ecc9214..db2eba6 100644 --- a/tensorneat/algorithm/neat/gene/node/default_without_response.py +++ b/tensorneat/algorithm/neat/gene/node/default_without_response.py @@ -137,7 +137,7 @@ class NodeGeneWithoutResponse(BaseNodeGene): idx = int(idx) - bias = np.array(bias, dtype=np.float32) + bias = jnp.float32(bias) agg = int(agg) act = int(act) diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index f68f70b..2c10aa2 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -216,7 +216,6 @@ class DefaultGenome(BaseGenome): symbols[i] = sp.Symbol(f"h{i}") nodes_exprs = {} - args_symbols = {} for i in order: diff --git a/tensorneat/examples/interpret_visualize/genome_sympy.ipynb b/tensorneat/examples/interpret_visualize/genome_sympy.ipynb index 3b47899..1c93aa0 100644 --- a/tensorneat/examples/interpret_visualize/genome_sympy.ipynb +++ b/tensorneat/examples/interpret_visualize/genome_sympy.ipynb @@ -18,30 +18,31 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:48:58.065855900Z", - "start_time": "2024-06-12T21:48:57.292767Z" + "end_time": "2024-06-12T22:58:32.064076300Z", + "start_time": "2024-06-12T22:58:31.208435400Z" } }, "id": "9531a569d9ecf774" }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 32, "outputs": [], "source": [ "genome = AdvanceInitialize(\n", - " num_inputs=3,\n", - " num_outputs=3,\n", + " num_inputs=20,\n", + " num_outputs=1,\n", " hidden_cnt=2,\n", - " max_nodes=50,\n", - " max_conns=500,\n", + " max_nodes=30,\n", + " max_conns=50,\n", " node_gene=NodeGeneWithoutResponse(\n", - " # activation_default=Act.tanh,\n", + " activation_default= Act.identity,\n", " aggregation_default=Agg.sum,\n", - " # activation_options=(Act.tanh,),\n", + " # activation_options=(Act.tanh, Act.sigmoid, Act.identity, Act.clamped),\n", + " activation_options=( Act.identity, ),\n", " aggregation_options=(Agg.sum,),\n", " ),\n", - " output_transform=jnp.tanh,\n", + " # output_transform=jnp.tanh,\n", ")\n", "\n", "state = genome.setup()\n", @@ -54,92 +55,68 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:49:03.858545Z", - "start_time": "2024-06-12T21:48:58.071859800Z" + "end_time": "2024-06-12T23:06:06.346267500Z", + "start_time": "2024-06-12T23:06:06.121268500Z" } }, "id": "4013c9f9d5472eb7" }, { "cell_type": "code", - "execution_count": 3, - "outputs": [ - { - "data": { - "text/plain": "{'nodes': {0: {'idx': 0,\n 'bias': array(0.22059791, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 1: {'idx': 1,\n 'bias': array(0.7715081, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 2: {'idx': 2,\n 'bias': array(1.1184921, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 3: {'idx': 3,\n 'bias': array(0.6967973, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 4: {'idx': 4,\n 'bias': array(0.85948837, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 5: {'idx': 5,\n 'bias': array(0.19332138, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 6: {'idx': 6,\n 'bias': array(-0.31763914, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 7: {'idx': 7,\n 'bias': array(0.05656302, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'}},\n 'conns': {(0, 6): {'in': 0,\n 'out': 6,\n 'weight': array(1.6676894, dtype=float32)},\n (0, 7): {'in': 0, 'out': 7, 'weight': array(-0.05250553, dtype=float32)},\n (1, 6): {'in': 1, 'out': 6, 'weight': array(0.10137014, dtype=float32)},\n (1, 7): {'in': 1, 'out': 7, 'weight': array(-0.12093307, dtype=float32)},\n (2, 6): {'in': 2, 'out': 6, 'weight': array(-1.8677292, dtype=float32)},\n (2, 7): {'in': 2, 'out': 7, 'weight': array(-0.4195783, dtype=float32)},\n (6, 3): {'in': 6, 'out': 3, 'weight': array(1.2615877, dtype=float32)},\n (6, 4): {'in': 6, 'out': 4, 'weight': array(-0.27593768, dtype=float32)},\n (6, 5): {'in': 6, 'out': 5, 'weight': array(-0.5819819, dtype=float32)},\n (7, 3): {'in': 7, 'out': 3, 'weight': array(0.59301573, dtype=float32)},\n (7, 4): {'in': 7, 'out': 4, 'weight': array(0.19493186, dtype=float32)},\n (7, 5): {'in': 7, 'out': 5, 'weight': array(0.18183969, dtype=float32)}}}" - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "network" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-12T21:49:03.873543600Z", - "start_time": "2024-06-12T21:49:03.867543Z" - } - }, - "id": "188006cebb04847" - }, - { - "cell_type": "code", - "execution_count": 11, + "execution_count": 33, "outputs": [], "source": [ "import sympy as sp\n", "\n", "# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)\n", - "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n", - "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, np_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh, backend='numpy')\n" + "# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n", + "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:50:37.527882500Z", - "start_time": "2024-06-12T21:50:37.518559400Z" + "end_time": "2024-06-12T23:06:06.698824100Z", + "start_time": "2024-06-12T23:06:06.688855300Z" } }, "id": "addea793fc002900" }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 38, "outputs": [ { "data": { - "text/plain": "(array([1.0719017 , 0.09353136, 0.22664611], dtype=float32), dtype('float32'))" + "text/plain": "(Array([-0.10080967, -2.373122 , -0.12224621, 1.0417817 , 0.26311624,\n -0.04573117, 0.5329444 , 1.9844177 , -0.5471916 , -3.0961084 ,\n 0.07978257, -1.0657575 , -1.6740963 , 1.2435746 , -0.5811825 ,\n 0.8970058 , -0.4379712 , 0.9084878 , -1.0984142 , 0.33063456], dtype=float32),\n dtype('float32'))" }, - "execution_count": 12, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "random_inputs = np.random.randn(3).astype(np.float32)\n", + "random_inputs = np.random.randn(20).astype(np.float32)\n", + "random_inputs = jax.device_put(random_inputs)\n", "random_inputs, random_inputs.dtype" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:50:38.178769100Z", - "start_time": "2024-06-12T21:50:38.155744Z" + "end_time": "2024-06-12T23:06:13.317466Z", + "start_time": "2024-06-12T23:06:13.251298500Z" } }, "id": "3aa7c874f3a5743f" }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 39, "outputs": [ { "data": { - "text/plain": "Array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32, weak_type=True)" + "text/plain": "Array([-3.769288], dtype=float32, weak_type=True)" }, - "execution_count": 13, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -152,72 +129,286 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:50:48.747287900Z", - "start_time": "2024-06-12T21:50:48.560675400Z" + "end_time": "2024-06-12T23:06:13.734130200Z", + "start_time": "2024-06-12T23:06:13.530130300Z" } }, "id": "fe3449a5bc688bc3" }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 40, "outputs": [ { "data": { - "text/plain": "(array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32),\n array([ 0.9743453 , 0.57646036, -0.3080282 ], dtype=float32),\n array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32))" + "text/plain": "(Array([-3.7692878], dtype=float32),\n Array([-3.769288], dtype=float32, weak_type=True))" }, - "execution_count": 14, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "res1 = np.array(jax_forward_func(random_inputs), dtype=np.float32)\n", - "res2 = np.array(np_forward_func(random_inputs), dtype=np.float32)\n", - "res = np.array(genome.forward(state, transformed, random_inputs))\n", - "res1, res2, res" + "res1 = jnp.array(jax_forward_func(random_inputs))\n", + "res2 = genome.forward(state, transformed, random_inputs)\n", + "res1, res2" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:51:15.098948600Z", - "start_time": "2024-06-12T21:51:14.908948500Z" + "end_time": "2024-06-12T23:06:14.069158700Z", + "start_time": "2024-06-12T23:06:13.857130700Z" } }, "id": "a874d434509f1092" }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 41, "outputs": [ { "data": { - "text/plain": "(array([ True, True, True]), array([ True, False, True]))" + "text/plain": "False" }, - "execution_count": 15, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "res1 == res, res2 == res" + "all(res1 == res)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:51:25.857465200Z", - "start_time": "2024-06-12T21:51:25.851465300Z" + "end_time": "2024-06-12T23:06:14.302366300Z", + "start_time": "2024-06-12T23:06:14.276336800Z" } }, "id": "d226e5bd6e2d44d6" }, + { + "cell_type": "code", + "execution_count": 73, + "outputs": [], + "source": [ + "random_inputs = np.random.randn(1000).astype(np.float32)\n", + "random_inputs = jax.device_put(random_inputs)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:36:13.519243800Z", + "start_time": "2024-06-12T22:36:13.502174700Z" + } + }, + "id": "90dd091f91df6fa2" + }, + { + "cell_type": "code", + "execution_count": 74, + "outputs": [], + "source": [ + "res1 = 1.243123123 + random_inputs * 1.12413243123123" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:36:13.898386100Z", + "start_time": "2024-06-12T22:36:13.886386800Z" + } + }, + "id": "b35d2c01f50071c2" + }, + { + "cell_type": "code", + "execution_count": 75, + "outputs": [], + "source": [ + "res2 = random_inputs * 1.12413243123123 + 1.243123123" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:36:14.245129600Z", + "start_time": "2024-06-12T22:36:14.220718100Z" + } + }, + "id": "e69752c00bd32361" + }, + { + "cell_type": "code", + "execution_count": 76, + "outputs": [ + { + "data": { + "text/plain": "Array([ True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True], dtype=bool)" + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res1 == res2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:36:14.573137100Z", + "start_time": "2024-06-12T22:36:14.551029Z" + } + }, + "id": "78b9e00669100231" + }, + { + "cell_type": "code", + "execution_count": 77, + "outputs": [ + { + "data": { + "text/plain": "(Array(1251.1074, dtype=float32), Array(1251.1078, dtype=float32))" + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "resres1 = 0.\n", + "for i in range(1000):\n", + " resres1 += res1[i]\n", + "resres2 = jnp.sum(res2)\n", + "resres1, resres2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:36:15.688800500Z", + "start_time": "2024-06-12T22:36:15.088955300Z" + } + }, + "id": "91d739b6b1520363" + }, + { + "cell_type": "code", + "execution_count": 78, + "outputs": [ + { + "data": { + "text/plain": "Array(False, dtype=bool)" + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "resres1 == resres2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:36:16.917702Z", + "start_time": "2024-06-12T22:36:16.906916800Z" + } + }, + "id": "60502ba8f400bb60" + }, + { + "cell_type": "code", + "execution_count": 66, + "outputs": [ + { + "data": { + "text/plain": "(Array([-0.0913986 , 0.6177338 , 3.704111 , 1.067648 , 3.5810733 ,\n 0.3716032 , -0.10655618, 1.3503847 , 0.97305036, 0.7711922 ], dtype=float32),\n Array([-0.0913986 , 0.6177338 , 3.704111 , 1.067648 , 3.5810733 ,\n 0.3716032 , -0.10655618, 1.3503847 , 0.97305036, 0.7711922 ], dtype=float32))" + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res1, res2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:35:50.915413100Z", + "start_time": "2024-06-12T22:35:50.868608Z" + } + }, + "id": "c1c092879ca33c3c" + }, + { + "cell_type": "code", + "execution_count": 20, + "outputs": [], + "source": [ + "real = 10\n", + "full = 50000\n", + "random_inputs = np.random.randn(real).astype(np.float32)\n", + "random_inputs = jax.device_put(random_inputs)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:55:02.540229100Z", + "start_time": "2024-06-12T22:55:02.530229200Z" + } + }, + "id": "5e1e7555937f1690" + }, + { + "cell_type": "code", + "execution_count": 21, + "outputs": [], + "source": [ + "all_nans = jnp.full((full,), jnp.nan)\n", + "large = all_nans.at[:real].set(random_inputs)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:55:02.971772100Z", + "start_time": "2024-06-12T22:55:02.961150Z" + } + }, + "id": "5ea2c2051880ca1e" + }, + { + "cell_type": "code", + "execution_count": 22, + "outputs": [ + { + "data": { + "text/plain": "Array(-5.8886395, dtype=float32)" + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res1 = jnp.sum(large, where=~jnp.isnan(large))\n", + "res1" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:55:03.381951100Z", + "start_time": "2024-06-12T22:55:03.370952400Z" + } + }, + "id": "9b4baa99710fad1" + }, { "cell_type": "code", "execution_count": 23, "outputs": [ { "data": { - "text/plain": "array([False, False, True])" + "text/plain": "Array(-5.8886395, dtype=float32)" }, "execution_count": 23, "metadata": {}, @@ -225,16 +416,52 @@ } ], "source": [ - "np.floor(res1 * 10000000) / 10000000 == np.floor(res2 * 10000000) / 10000000" + "res2 = jnp.sum(random_inputs)\n", + "res2" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T21:00:19.851215800Z", - "start_time": "2024-06-12T21:00:19.836443700Z" + "end_time": "2024-06-12T22:55:03.781919300Z", + "start_time": "2024-06-12T22:55:03.770892100Z" } }, - "id": "2a36ce6afc59ee8a" + "id": "9b2897759c90b7c5" + }, + { + "cell_type": "code", + "execution_count": 24, + "outputs": [ + { + "data": { + "text/plain": "Array(True, dtype=bool)" + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res1 == res2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T22:55:04.328871600Z", + "start_time": "2024-06-12T22:55:04.314871Z" + } + }, + "id": "e312213166610144" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "7e9ca9fa15e3c401" } ], "metadata": { diff --git a/tensorneat/utils/aggregation/agg_jnp.py b/tensorneat/utils/aggregation/agg_jnp.py index f604148..3359c6b 100644 --- a/tensorneat/utils/aggregation/agg_jnp.py +++ b/tensorneat/utils/aggregation/agg_jnp.py @@ -9,23 +9,19 @@ class Agg: @staticmethod def sum(z): - z = jnp.where(jnp.isnan(z), 0, z) - return jnp.sum(z, axis=0) + return jnp.sum(z, axis=0, where=~jnp.isnan(z)) @staticmethod def product(z): - z = jnp.where(jnp.isnan(z), 1, z) - return jnp.prod(z, axis=0) + return jnp.prod(z, axis=0, where=~jnp.isnan(z)) @staticmethod def max(z): - z = jnp.where(jnp.isnan(z), -jnp.inf, z) - return jnp.max(z, axis=0) + return jnp.max(z, axis=0, where=~jnp.isnan(z)) @staticmethod def min(z): - z = jnp.where(jnp.isnan(z), jnp.inf, z) - return jnp.min(z, axis=0) + return jnp.min(z, axis=0, where=~jnp.isnan(z)) @staticmethod def maxabs(z): diff --git a/tensorneat/utils/aggregation/agg_sympy.py b/tensorneat/utils/aggregation/agg_sympy.py index 065dfce..0890a49 100644 --- a/tensorneat/utils/aggregation/agg_sympy.py +++ b/tensorneat/utils/aggregation/agg_sympy.py @@ -7,6 +7,10 @@ class SympySum(sp.Function): def eval(cls, z): return sp.Add(*z) + @classmethod + def numerical_eval(cls, z, backend=np): + return backend.sum(z) + class SympyProduct(sp.Function): @classmethod