From 69d73aab73d3749c9451538f8908c02b9ddb800c Mon Sep 17 00:00:00 2001 From: wls2002 Date: Thu, 13 Jun 2024 05:55:33 +0800 Subject: [PATCH] add backend="jax" to sympy module --- tensorneat/algorithm/neat/gene/conn/base.py | 2 +- .../algorithm/neat/gene/conn/default.py | 11 +- tensorneat/algorithm/neat/gene/node/base.py | 2 +- .../algorithm/neat/gene/node/default.py | 44 +++-- .../gene/node/default_without_response.py | 41 ++-- tensorneat/algorithm/neat/genome/base.py | 4 +- tensorneat/algorithm/neat/genome/default.py | 56 +++++- .../interpret_visualize/genome_sympy.ipynb | 186 +++++++++++------- .../interpret_visualize/genome_sympy.py | 4 + tensorneat/utils/__init__.py | 9 +- tensorneat/utils/activation/act_sympy.py | 52 ++--- tensorneat/utils/aggregation/agg_sympy.py | 10 +- 12 files changed, 254 insertions(+), 167 deletions(-) diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py index e8ba377..da4d54a 100644 --- a/tensorneat/algorithm/neat/gene/conn/base.py +++ b/tensorneat/algorithm/neat/gene/conn/base.py @@ -39,5 +39,5 @@ class BaseConnGene(BaseGene): "out": int(out_idx), } - def sympy_func(self, state, conn_dict, inputs, precision=None): + def sympy_func(self, state, conn_dict, inputs): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index 15ff4b7..ef94cb7 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -1,6 +1,7 @@ import jax.numpy as jnp import jax.random - +import numpy as np +import sympy as sp from utils import mutate_float from . import BaseConnGene @@ -81,12 +82,10 @@ class DefaultConnGene(BaseConnGene): return { "in": int(conn[0]), "out": int(conn[1]), - "weight": float(conn[2]), + "weight": np.array(conn[2], dtype=np.float32), } def sympy_func(self, state, conn_dict, inputs, precision=None): - weight = conn_dict["weight"] - if precision is not None: - weight = round(weight, precision) + weight = sp.symbols(f"c_{conn_dict['in']}_{conn_dict['out']}_w") - return inputs * weight + return inputs * weight, {weight: conn_dict["weight"]} diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index f4ab28e..30e324d 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -54,5 +54,5 @@ class BaseNodeGene(BaseGene): "idx": int(idx), } - def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None): + def sympy_func(self, state, node_dict, inputs, is_output_node=False): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 2b3d7c1..4ca695f 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -2,6 +2,7 @@ from typing import Tuple import numpy as np import jax, jax.numpy as jnp +import sympy as sp from utils import ( Act, @@ -160,33 +161,36 @@ class DefaultNodeGene(BaseNodeGene): def to_dict(self, state, node): idx, bias, res, agg, act = node + + idx = int(idx) + bias = np.array(bias, dtype=np.float32) + res = np.array(res, dtype=np.float32) + agg = int(agg) + act = int(act) + + if act == -1: + act_func = Act.identity + else: + act_func = self.activation_options[act] return { - "idx": int(idx), - "bias": float(bias), - "res": float(res), + "idx": idx, + "bias": bias, + "res": res, "agg": self.aggregation_options[int(agg)].__name__, - "act": self.activation_options[int(act)].__name__, + "act": act_func.__name__, } - def sympy_func( - self, state, node_dict, inputs, is_output_node=False, precision=None - ): + def sympy_func(self, state, node_dict, inputs, is_output_node=False): + nd = node_dict + bias = sp.symbols(f"n_{nd['idx']}_b") + res = sp.symbols(f"n_{nd['idx']}_r") - bias = node_dict["bias"] - res = node_dict["res"] - agg = node_dict["agg"] - act = node_dict["act"] - - if precision is not None: - bias = round(bias, precision) - res = round(res, precision) - - z = convert_to_sympy(agg)(inputs) + z = convert_to_sympy(nd["agg"])(inputs) z = bias + z * res if is_output_node: - return z + pass else: - z = convert_to_sympy(act)(z) + z = convert_to_sympy(nd["act"])(z) - return z + return z, {bias: nd["bias"], res: nd["res"]} diff --git a/tensorneat/algorithm/neat/gene/node/default_without_response.py b/tensorneat/algorithm/neat/gene/node/default_without_response.py index a798987..ecc9214 100644 --- a/tensorneat/algorithm/neat/gene/node/default_without_response.py +++ b/tensorneat/algorithm/neat/gene/node/default_without_response.py @@ -1,7 +1,8 @@ from typing import Tuple import jax, jax.numpy as jnp - +import numpy as np +import sympy as sp from utils import ( Act, Agg, @@ -133,30 +134,36 @@ class NodeGeneWithoutResponse(BaseNodeGene): def to_dict(self, state, node): idx, bias, agg, act = node + + idx = int(idx) + + bias = np.array(bias, dtype=np.float32) + agg = int(agg) + act = int(act) + + if act == -1: + act_func = Act.identity + else: + act_func = self.activation_options[act] + return { - "idx": int(idx), - "bias": float(bias), + "idx": idx, + "bias": bias, "agg": self.aggregation_options[int(agg)].__name__, - "act": self.activation_options[int(act)].__name__, + "act": act_func.__name__, } - def sympy_func( - self, state, node_dict, inputs, is_output_node=False, precision=None - ): + def sympy_func(self, state, node_dict, inputs, is_output_node=False): + nd = node_dict - bias = node_dict["bias"] - agg = node_dict["agg"] - act = node_dict["act"] + bias = sp.symbols(f"n_{nd['idx']}_b") - if precision is not None: - bias = round(bias, precision) + z = convert_to_sympy(nd["agg"])(inputs) - z = convert_to_sympy(agg)(inputs) z = bias + z - if is_output_node: - return z + pass else: - z = convert_to_sympy(act)(z) + z = convert_to_sympy(nd["act"])(z) - return z + return z, {bias: nd["bias"]} diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index 71f2903..aa16672 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -164,7 +164,6 @@ class BaseGenome(StatefulBaseClass): continue cd = self.conn_gene.to_dict(state, conn) in_idx, out_idx = cd["in"], cd["out"] - del cd["in"], cd["out"] conn_dict[(in_idx, out_idx)] = cd return conn_dict @@ -176,7 +175,6 @@ class BaseGenome(StatefulBaseClass): continue nd = self.node_gene.to_dict(state, node) idx = nd["idx"] - del nd["idx"] node_dict[idx] = nd return node_dict @@ -192,7 +190,7 @@ class BaseGenome(StatefulBaseClass): def get_output_idx(self): return self.output_idx.tolist() - def sympy_func(self, state, network, precision=3): + def sympy_func(self, state, network, sympy_output_transform=None): raise NotImplementedError def visualize( diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 3790b66..f68f70b 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable import jax, jax.numpy as jnp @@ -12,7 +13,8 @@ from utils import ( set_node_attrs, set_conn_attrs, attach_with_inf, - FUNCS_MODULE, + SYMPY_FUNCS_MODULE_NP, + SYMPY_FUNCS_MODULE_JNP, ) from . import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene @@ -191,7 +193,16 @@ class DefaultGenome(BaseGenome): new_transformed, ) - def sympy_func(self, state, network, precision=3): + def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"): + + assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'" + module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP + + if sympy_output_transform is None and self.output_transform is not None: + warnings.warn( + "genome.output_transform is not None but sympy_output_transform is None!" + ) + input_idx = self.get_input_idx() output_idx = self.get_output_idx() order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) @@ -206,6 +217,7 @@ class DefaultGenome(BaseGenome): nodes_exprs = {} + args_symbols = {} for i in order: if i in input_idx: @@ -215,20 +227,25 @@ class DefaultGenome(BaseGenome): node_inputs = [] for conn in in_conns: val_represent = symbols[conn[0]] - val = self.conn_gene.sympy_func( + # a_s -> args_symbols + val, a_s = self.conn_gene.sympy_func( state, network["conns"][conn], val_represent, - precision=precision, ) + args_symbols.update(a_s) node_inputs.append(val) - nodes_exprs[symbols[i]] = self.node_gene.sympy_func( + nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func( state, network["nodes"][i], node_inputs, is_output_node=(i in output_idx), - precision=precision, ) + args_symbols.update(a_s) + if i in output_idx and sympy_output_transform is not None: + nodes_exprs[symbols[i]] = sympy_output_transform( + nodes_exprs[symbols[i]] + ) input_symbols = [v for k, v in symbols.items() if k in input_idx] reduced_exprs = nodes_exprs.copy() @@ -236,10 +253,31 @@ class DefaultGenome(BaseGenome): reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs) output_exprs = [reduced_exprs[symbols[i]] for i in output_idx] + lambdify_output_funcs = [ - sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE]) + sp.lambdify( + input_symbols + list(args_symbols.keys()), + exprs, + modules=[backend, module], + ) for exprs in output_exprs ] - forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs] - return symbols, input_symbols, nodes_exprs, output_exprs, forward_func + fixed_args_output_funcs = [] + for i in range(len(output_idx)): + + def f(inputs, i=i): + return lambdify_output_funcs[i](*inputs, *args_symbols.values()) + + fixed_args_output_funcs.append(f) + + forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs] + + return ( + symbols, + args_symbols, + input_symbols, + nodes_exprs, + output_exprs, + forward_func, + ) diff --git a/tensorneat/examples/interpret_visualize/genome_sympy.ipynb b/tensorneat/examples/interpret_visualize/genome_sympy.ipynb index f0f0981..3b47899 100644 --- a/tensorneat/examples/interpret_visualize/genome_sympy.ipynb +++ b/tensorneat/examples/interpret_visualize/genome_sympy.ipynb @@ -11,13 +11,15 @@ "from algorithm.neat.genome.advance import AdvanceInitialize\n", "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "from utils.graph import topological_sort_python\n", - "from utils import Act, Agg" + "from utils import Act, Agg\n", + "\n", + "import numpy as np" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T11:35:46.886073700Z", - "start_time": "2024-06-12T11:35:46.042288800Z" + "end_time": "2024-06-12T21:48:58.065855900Z", + "start_time": "2024-06-12T21:48:57.292767Z" } }, "id": "9531a569d9ecf774" @@ -29,8 +31,8 @@ "source": [ "genome = AdvanceInitialize(\n", " num_inputs=3,\n", - " num_outputs=1,\n", - " hidden_cnt=1,\n", + " num_outputs=3,\n", + " hidden_cnt=2,\n", " max_nodes=50,\n", " max_conns=500,\n", " node_gene=NodeGeneWithoutResponse(\n", @@ -38,7 +40,8 @@ " aggregation_default=Agg.sum,\n", " # activation_options=(Act.tanh,),\n", " aggregation_options=(Agg.sum,),\n", - " )\n", + " ),\n", + " output_transform=jnp.tanh,\n", ")\n", "\n", "state = genome.setup()\n", @@ -51,8 +54,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T11:35:52.274062400Z", - "start_time": "2024-06-12T11:35:46.892042200Z" + "end_time": "2024-06-12T21:49:03.858545Z", + "start_time": "2024-06-12T21:48:58.071859800Z" } }, "id": "4013c9f9d5472eb7" @@ -63,7 +66,7 @@ "outputs": [ { "data": { - "text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]" + "text/plain": "{'nodes': {0: {'idx': 0,\n 'bias': array(0.22059791, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 1: {'idx': 1,\n 'bias': array(0.7715081, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 2: {'idx': 2,\n 'bias': array(1.1184921, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 3: {'idx': 3,\n 'bias': array(0.6967973, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 4: {'idx': 4,\n 'bias': array(0.85948837, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 5: {'idx': 5,\n 'bias': array(0.19332138, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 6: {'idx': 6,\n 'bias': array(-0.31763914, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 7: {'idx': 7,\n 'bias': array(0.05656302, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'}},\n 'conns': {(0, 6): {'in': 0,\n 'out': 6,\n 'weight': array(1.6676894, dtype=float32)},\n (0, 7): {'in': 0, 'out': 7, 'weight': array(-0.05250553, dtype=float32)},\n (1, 6): {'in': 1, 'out': 6, 'weight': array(0.10137014, dtype=float32)},\n (1, 7): {'in': 1, 'out': 7, 'weight': array(-0.12093307, dtype=float32)},\n (2, 6): {'in': 2, 'out': 6, 'weight': array(-1.8677292, dtype=float32)},\n (2, 7): {'in': 2, 'out': 7, 'weight': array(-0.4195783, dtype=float32)},\n (6, 3): {'in': 6, 'out': 3, 'weight': array(1.2615877, dtype=float32)},\n (6, 4): {'in': 6, 'out': 4, 'weight': array(-0.27593768, dtype=float32)},\n (6, 5): {'in': 6, 'out': 5, 'weight': array(-0.5819819, dtype=float32)},\n (7, 3): {'in': 7, 'out': 3, 'weight': array(0.59301573, dtype=float32)},\n (7, 4): {'in': 7, 'out': 4, 'weight': array(0.19493186, dtype=float32)},\n (7, 5): {'in': 7, 'out': 5, 'weight': array(0.18183969, dtype=float32)}}}" }, "execution_count": 3, "metadata": {}, @@ -71,89 +74,72 @@ } ], "source": [ - "import sympy as sp\n", - "\n", - "symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n", - "output_exprs" + "network" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T11:35:52.325161800Z", - "start_time": "2024-06-12T11:35:52.282008300Z" + "end_time": "2024-06-12T21:49:03.873543600Z", + "start_time": "2024-06-12T21:49:03.867543Z" + } + }, + "id": "188006cebb04847" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "import sympy as sp\n", + "\n", + "# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)\n", + "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n", + "symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, np_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh, backend='numpy')\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T21:50:37.527882500Z", + "start_time": "2024-06-12T21:50:37.518559400Z" } }, "id": "addea793fc002900" }, { "cell_type": "code", - "execution_count": 4, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n" - ] - } - ], - "source": [ - "print(sp.latex(output_exprs[0]))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-12T11:35:52.341639700Z", - "start_time": "2024-06-12T11:35:52.323163700Z" - } - }, - "id": "967cb87e24373f77" - }, - { - "cell_type": "markdown", - "source": [], - "metadata": { - "collapsed": false - }, - "id": "88eee4db9eb857cd" - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 12, "outputs": [ { "data": { - "text/plain": "[-0.7940936986556304]" + "text/plain": "(array([1.0719017 , 0.09353136, 0.22664611], dtype=float32), dtype('float32'))" }, - "execution_count": 5, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import numpy as np\n", - "random_inputs = np.random.randn(3)\n", - "res = forward_func(random_inputs)\n", - "res " + "random_inputs = np.random.randn(3).astype(np.float32)\n", + "random_inputs, random_inputs.dtype" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T11:35:52.342638Z", - "start_time": "2024-06-12T11:35:52.330160600Z" + "end_time": "2024-06-12T21:50:38.178769100Z", + "start_time": "2024-06-12T21:50:38.155744Z" } }, - "id": "c5581201d990ba1c" + "id": "3aa7c874f3a5743f" }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 13, "outputs": [ { "data": { - "text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)" + "text/plain": "Array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32, weak_type=True)" }, - "execution_count": 6, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -166,25 +152,89 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T11:35:53.273851900Z", - "start_time": "2024-06-12T11:35:52.384588600Z" + "end_time": "2024-06-12T21:50:48.747287900Z", + "start_time": "2024-06-12T21:50:48.560675400Z" } }, "id": "fe3449a5bc688bc3" }, { "cell_type": "code", - "execution_count": 6, - "outputs": [], - "source": [], + "execution_count": 14, + "outputs": [ + { + "data": { + "text/plain": "(array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32),\n array([ 0.9743453 , 0.57646036, -0.3080282 ], dtype=float32),\n array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32))" + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res1 = np.array(jax_forward_func(random_inputs), dtype=np.float32)\n", + "res2 = np.array(np_forward_func(random_inputs), dtype=np.float32)\n", + "res = np.array(genome.forward(state, transformed, random_inputs))\n", + "res1, res2, res" + ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-06-12T11:35:53.274854100Z", - "start_time": "2024-06-12T11:35:53.265856700Z" + "end_time": "2024-06-12T21:51:15.098948600Z", + "start_time": "2024-06-12T21:51:14.908948500Z" } }, - "id": "174c7dc3d9499f95" + "id": "a874d434509f1092" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [ + { + "data": { + "text/plain": "(array([ True, True, True]), array([ True, False, True]))" + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res1 == res, res2 == res" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T21:51:25.857465200Z", + "start_time": "2024-06-12T21:51:25.851465300Z" + } + }, + "id": "d226e5bd6e2d44d6" + }, + { + "cell_type": "code", + "execution_count": 23, + "outputs": [ + { + "data": { + "text/plain": "array([False, False, True])" + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.floor(res1 * 10000000) / 10000000 == np.floor(res2 * 10000000) / 10000000" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-12T21:00:19.851215800Z", + "start_time": "2024-06-12T21:00:19.836443700Z" + } + }, + "id": "2a36ce6afc59ee8a" } ], "metadata": { diff --git a/tensorneat/examples/interpret_visualize/genome_sympy.py b/tensorneat/examples/interpret_visualize/genome_sympy.py index 1304d4a..268998c 100644 --- a/tensorneat/examples/interpret_visualize/genome_sympy.py +++ b/tensorneat/examples/interpret_visualize/genome_sympy.py @@ -28,3 +28,7 @@ if __name__ == '__main__': print(genome.repr(state, nodes, conns)) print(network) + + res = genome.sympy_func(state, network, precision=3) + print(res) + diff --git a/tensorneat/utils/__init__.py b/tensorneat/utils/__init__.py index aa94939..b2f8df0 100644 --- a/tensorneat/utils/__init__.py +++ b/tensorneat/utils/__init__.py @@ -1,3 +1,5 @@ +import jax.numpy as jnp + from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL from .tools import * from .graph import * @@ -29,6 +31,7 @@ name2sympy = { "min": SympyMin, "maxabs": SympyMaxabs, "mean": SympyMean, + "clip": SympyClip, } @@ -45,7 +48,9 @@ def convert_to_sympy(func: Union[str, callable]): ) -FUNCS_MODULE = {} +SYMPY_FUNCS_MODULE_NP = {} +SYMPY_FUNCS_MODULE_JNP = {} for cls in name2sympy.values(): if hasattr(cls, "numerical_eval"): - FUNCS_MODULE[cls.__name__] = cls.numerical_eval + SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval + SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp) diff --git a/tensorneat/utils/activation/act_sympy.py b/tensorneat/utils/activation/act_sympy.py index b58eeb2..7cb043d 100644 --- a/tensorneat/utils/activation/act_sympy.py +++ b/tensorneat/utils/activation/act_sympy.py @@ -1,4 +1,3 @@ -from typing import Union import sympy as sp import numpy as np @@ -13,8 +12,8 @@ class SympyClip(sp.Function): return None @staticmethod - def numerical_eval(val, min_val, max_val): - return np.clip(val, min_val, max_val) + def numerical_eval(val, min_val, max_val, backend=np): + return backend.clip(val, min_val, max_val) def _sympystr(self, printer): return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})" @@ -32,9 +31,9 @@ class SympySigmoid(sp.Function): return None @staticmethod - def numerical_eval(z): - z = np.clip(5 * z, -10, 10) - return 1 / (1 + np.exp(-z)) + def numerical_eval(z, backend=np): + z = backend.clip(5 * z, -10, 10) + return 1 / (1 + backend.exp(-z)) def _sympystr(self, printer): return f"sigmoid({self.args[0]})" @@ -49,8 +48,8 @@ class SympyTanh(sp.Function): return sp.tanh(0.6 * z) @staticmethod - def numerical_eval(z): - return np.tanh(0.6 * z) + def numerical_eval(z, backend=np): + return backend.tanh(0.6 * z) class SympySin(sp.Function): @@ -59,8 +58,8 @@ class SympySin(sp.Function): return sp.sin(z) @staticmethod - def numerical_eval(z): - return np.sin(z) + def numerical_eval(z, backend=np): + return backend.sin(z) class SympyRelu(sp.Function): @@ -71,8 +70,8 @@ class SympyRelu(sp.Function): return None @staticmethod - def numerical_eval(z): - return np.maximum(z, 0) + def numerical_eval(z, backend=np): + return backend.maximum(z, 0) def _sympystr(self, printer): return f"relu({self.args[0]})" @@ -90,9 +89,9 @@ class SympyLelu(sp.Function): return None @staticmethod - def numerical_eval(z): + def numerical_eval(z, backend=np): leaky = 0.005 - return np.maximum(z, leaky * z) + return backend.maximum(z, leaky * z) def _sympystr(self, printer): return f"lelu({self.args[0]})" @@ -107,7 +106,7 @@ class SympyIdentity(sp.Function): return z @staticmethod - def numerical_eval(z): + def numerical_eval(z, backend=np): return z @@ -117,8 +116,8 @@ class SympyClamped(sp.Function): return SympyClip(z, -1, 1) @staticmethod - def numerical_eval(z): - return np.clip(z, -1, 1) + def numerical_eval(z, backend=np): + return backend.clip(z, -1, 1) class SympyInv(sp.Function): @@ -130,8 +129,8 @@ class SympyInv(sp.Function): return None @staticmethod - def numerical_eval(z): - z = np.maximum(z, 1e-7) + def numerical_eval(z, backend=np): + z = backend.maximum(z, 1e-7) return 1 / z def _sympystr(self, printer): @@ -150,9 +149,9 @@ class SympyLog(sp.Function): return None @staticmethod - def numerical_eval(z): - z = np.maximum(z, 1e-7) - return np.log(z) + def numerical_eval(z, backend=np): + z = backend.maximum(z, 1e-7) + return backend.log(z) def _sympystr(self, printer): return f"log({self.args[0]})" @@ -169,11 +168,6 @@ class SympyExp(sp.Function): return sp.exp(z) return None - @staticmethod - def numerical_eval(z): - z = np.clip(z, -10, 10) - return np.exp(z) - def _sympystr(self, printer): return f"exp({self.args[0]})" @@ -185,7 +179,3 @@ class SympyAbs(sp.Function): @classmethod def eval(cls, z): return sp.Abs(z) - - @staticmethod - def numerical_eval(z): - return np.abs(z) \ No newline at end of file diff --git a/tensorneat/utils/aggregation/agg_sympy.py b/tensorneat/utils/aggregation/agg_sympy.py index 8dc3807..065dfce 100644 --- a/tensorneat/utils/aggregation/agg_sympy.py +++ b/tensorneat/utils/aggregation/agg_sympy.py @@ -1,3 +1,4 @@ +import numpy as np import sympy as sp @@ -51,15 +52,6 @@ class SympyMedian(sp.Function): return None - @staticmethod - def numerical_eval(args): - sorted_args = sorted(args) - n = len(sorted_args) - if n % 2 == 1: - return sorted_args[n // 2] - else: - return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2 - def _sympystr(self, printer): return f"median({', '.join(map(str, self.args))})"