From 55f1c626d3fdd53d6545ebd5b5551c8c3aa6c124 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 19 Mar 2025 11:56:06 +0800 Subject: [PATCH] fix bugs in to python code, see issue(https://github.com/EMI-Group/tensorneat/issues/24#issuecomment-2731438212) --- src/tensorneat/common/sympy_tools.py | 82 +++++++++++----------------- src/tensorneat/genome/default.py | 1 + 2 files changed, 34 insertions(+), 49 deletions(-) diff --git a/src/tensorneat/common/sympy_tools.py b/src/tensorneat/common/sympy_tools.py index 1898e57..60c21b2 100644 --- a/src/tensorneat/common/sympy_tools.py +++ b/src/tensorneat/common/sympy_tools.py @@ -22,70 +22,54 @@ def round_expr(expr, precision=2): return expr.xreplace({n: round(n, precision) for n in expr.atoms(sp.Number)}) -def replace_variable_names(expression): +def replace_variable_names(expression, mode): """ Transform sympy expression to a string with array index that can be used in python code. - For example, `o0` will be transformed to `o[0]`. + For example, `o0` will be transformed to `o[0]` in Python mode, + and `o0` will be transformed to LaTeX format using sympy's `latex()` in LaTeX mode. """ + assert mode in ["python", "latex"] expression_str = str(expression) - expression_str = re.sub(r"\bo(\d+)\b", r"o[\1]", expression_str) - expression_str = re.sub(r"\bh(\d+)\b", r"h[\1]", expression_str) - expression_str = re.sub(r"\bi(\d+)\b", r"i[\1]", expression_str) + + if mode == "python": + expression_str = re.sub(r"\bo(\d+)\b", r"o[\1]", expression_str) + expression_str = re.sub(r"\bh(\d+)\b", r"h[\1]", expression_str) + expression_str = re.sub(r"\bi(\d+)\b", r"i[\1]", expression_str) + else: # latex mode + expression_str = re.sub(r"\bo(\d+)\b", lambda m: f"o_{{{m.group(1)}}}", expression_str) + expression_str = re.sub(r"\bh(\d+)\b", lambda m: f"h_{{{m.group(1)}}}", expression_str) + expression_str = re.sub(r"\bi(\d+)\b", lambda m: f"i_{{{m.group(1)}}}", expression_str) + return expression_str -def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True): +def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order): input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs) res = "\\begin{align}\n" - if not use_hidden_nodes: - for i in range(output_cnt): - expr = output_exprs[i].subs(args_symbols) - rounded_expr = round_expr(expr, 2) - latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n" - res += latex_expr - else: - for i in range(hidden_cnt): - symbol = sp.symbols(f"h{i}") - expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) - rounded_expr = round_expr(expr, 2) - latex_expr = f"h_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n" - res += latex_expr - for i in range(output_cnt): - symbol = sp.symbols(f"o{i}") - expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) - rounded_expr = round_expr(expr, 2) - latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n" - res += latex_expr + for i in topo_order[input_cnt: ]: + symbol = symbols[i] + expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) + rounded_expr = round_expr(expr, 2) + latex_expr = f"{symbol} &= {sp.latex(rounded_expr)}\\newline\n" + latex_expr = replace_variable_names(latex_expr, "latex") + res += latex_expr res += "\\end{align}\n" return res -def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True): +def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order): input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs) res = "" - if not use_hidden_nodes: - # pre-allocate space - res += f"o = np.zeros({output_cnt})\n" - for i in range(output_cnt): - expr = output_exprs[i].subs(args_symbols) - rounded_expr = round_expr(expr, 6) - str_expr = f"o{i} = {rounded_expr}" - res += replace_variable_names(str_expr) + "\n" - else: - # pre-allocate space + if hidden_cnt > 0: res += f"h = np.zeros({hidden_cnt})\n" - res += f"o = np.zeros({output_cnt})\n" - for i in range(hidden_cnt): - symbol = sp.symbols(f"h{i}") - expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) - rounded_expr = round_expr(expr, 6) - str_expr = f"h{i} = {rounded_expr}" - res += replace_variable_names(str_expr) + "\n" - for i in range(output_cnt): - symbol = sp.symbols(f"o{i}") - expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) - rounded_expr = round_expr(expr, 6) - str_expr = f"o{i} = {rounded_expr}" - res += replace_variable_names(str_expr) + "\n" + res += f"o = np.zeros({output_cnt})\n" + + for i in topo_order[input_cnt: ]: + symbol = symbols[i] + expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) + rounded_expr = round_expr(expr, 6) + str_expr = f"{symbol} = {rounded_expr}" + res += replace_variable_names(str_expr, "python") + "\n" + return res \ No newline at end of file diff --git a/src/tensorneat/genome/default.py b/src/tensorneat/genome/default.py index e308a4e..f4b51e5 100644 --- a/src/tensorneat/genome/default.py +++ b/src/tensorneat/genome/default.py @@ -251,6 +251,7 @@ class DefaultGenome(BaseGenome): nodes_exprs, output_exprs, forward_func, + network["topo_order"] ) def visualize(