Remove printing for useless nodes when generating python code or latex formula

This commit is contained in:
wls2002
2025-04-16 10:43:38 +08:00
parent 5fa5e81c72
commit ffc1f09be0
4 changed files with 39 additions and 5 deletions

View File

@@ -94,7 +94,28 @@ def topological_sort_python(
return topo_order, topo_layer
def find_useful_nodes(
nodes: Union[Set[int], List[int]],
conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
output_idx: Set[int],
) -> Set[int]:
"""
Find all useful nodes (really contribute to outputs)
"""
useful_nodes = set()
useful_nodes = useful_nodes | output_idx
while True:
aux = set()
for in_, out in conns:
if out in useful_nodes and in_ not in useful_nodes:
aux.add(in_)
if len(aux) == 0: # no new nodes
break
else:
useful_nodes = useful_nodes | aux
# print(f"All nodes cnt={len(nodes)}, useful nodes cnt={len(useful_nodes)}")
return useful_nodes
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
"""

View File

@@ -43,11 +43,14 @@ def replace_variable_names(expression, mode):
return expression_str
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order):
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes):
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
res = "\\begin{align}\n"
for i in topo_order[input_cnt: ]:
# do not add node that does not contribute to output; useful nodes may be nan and cause error
if i not in useful_nodes:
continue
symbol = symbols[i]
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
rounded_expr = round_expr(expr, 2)
@@ -58,7 +61,7 @@ def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_expr
return res
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order):
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes):
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
res = ""
if hidden_cnt > 0:
@@ -66,6 +69,9 @@ def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exp
res += f"o = np.zeros({output_cnt})\n"
for i in topo_order[input_cnt: ]:
# do not add node that does not contribute to output; useful nodes may be nan and cause error
if i not in useful_nodes:
continue
symbol = symbols[i]
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
rounded_expr = round_expr(expr, 6)

View File

@@ -13,6 +13,7 @@ from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
from tensorneat.common import (
topological_sort,
topological_sort_python,
find_useful_nodes,
I_INF,
attach_with_inf,
ACT,
@@ -131,6 +132,11 @@ class DefaultGenome(BaseGenome):
)
network["topo_order"] = topo_order
network["topo_layers"] = topo_layers
network["useful_nodes"] = find_useful_nodes(
set(network["nodes"]),
set(network["conns"]),
set(self.output_idx)
)
return network
def sympy_func(
@@ -251,7 +257,8 @@ class DefaultGenome(BaseGenome):
nodes_exprs,
output_exprs,
forward_func,
network["topo_order"]
network["topo_order"],
network["useful_nodes"]
)
def visualize(

View File

@@ -98,7 +98,7 @@ def re_cound_idx(nodes, conns, input_idx, output_idx):
for i, key in enumerate(nodes[:, 0]):
if np.isnan(key):
continue
if np.in1d(key, input_idx + output_idx):
if np.isin(key, input_idx + output_idx):
continue
old2new[int(key)] = next_key
next_key += 1