diff --git a/src/tensorneat/common/graph.py b/src/tensorneat/common/graph.py index f783b8e..c973f84 100644 --- a/src/tensorneat/common/graph.py +++ b/src/tensorneat/common/graph.py @@ -94,7 +94,28 @@ def topological_sort_python( return topo_order, topo_layer - +def find_useful_nodes( + nodes: Union[Set[int], List[int]], + conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]], + output_idx: Set[int], +) -> Set[int]: + """ + Find all useful nodes (really contribute to outputs) + """ + useful_nodes = set() + useful_nodes = useful_nodes | output_idx + while True: + aux = set() + for in_, out in conns: + if out in useful_nodes and in_ not in useful_nodes: + aux.add(in_) + if len(aux) == 0: # no new nodes + break + else: + useful_nodes = useful_nodes | aux + # print(f"All nodes cnt={len(nodes)}, useful nodes cnt={len(useful_nodes)}") + return useful_nodes + @jit def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: """ diff --git a/src/tensorneat/common/sympy_tools.py b/src/tensorneat/common/sympy_tools.py index 60c21b2..05c5879 100644 --- a/src/tensorneat/common/sympy_tools.py +++ b/src/tensorneat/common/sympy_tools.py @@ -43,11 +43,14 @@ def replace_variable_names(expression, mode): return expression_str -def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order): +def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes): input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs) res = "\\begin{align}\n" for i in topo_order[input_cnt: ]: + # do not add node that does not contribute to output; useful nodes may be nan and cause error + if i not in useful_nodes: + continue symbol = symbols[i] expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) rounded_expr = round_expr(expr, 2) @@ -58,7 +61,7 @@ def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_expr return res -def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order): +def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes): input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs) res = "" if hidden_cnt > 0: @@ -66,6 +69,9 @@ def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exp res += f"o = np.zeros({output_cnt})\n" for i in topo_order[input_cnt: ]: + # do not add node that does not contribute to output; useful nodes may be nan and cause error + if i not in useful_nodes: + continue symbol = symbols[i] expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols) rounded_expr = round_expr(expr, 6) diff --git a/src/tensorneat/genome/default.py b/src/tensorneat/genome/default.py index f4b51e5..948ce32 100644 --- a/src/tensorneat/genome/default.py +++ b/src/tensorneat/genome/default.py @@ -13,6 +13,7 @@ from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs from tensorneat.common import ( topological_sort, topological_sort_python, + find_useful_nodes, I_INF, attach_with_inf, ACT, @@ -131,6 +132,11 @@ class DefaultGenome(BaseGenome): ) network["topo_order"] = topo_order network["topo_layers"] = topo_layers + network["useful_nodes"] = find_useful_nodes( + set(network["nodes"]), + set(network["conns"]), + set(self.output_idx) + ) return network def sympy_func( @@ -251,7 +257,8 @@ class DefaultGenome(BaseGenome): nodes_exprs, output_exprs, forward_func, - network["topo_order"] + network["topo_order"], + network["useful_nodes"] ) def visualize( diff --git a/src/tensorneat/genome/utils.py b/src/tensorneat/genome/utils.py index dc67c3d..feac208 100644 --- a/src/tensorneat/genome/utils.py +++ b/src/tensorneat/genome/utils.py @@ -98,7 +98,7 @@ def re_cound_idx(nodes, conns, input_idx, output_idx): for i, key in enumerate(nodes[:, 0]): if np.isnan(key): continue - if np.in1d(key, input_idx + output_idx): + if np.isin(key, input_idx + output_idx): continue old2new[int(key)] = next_key next_key += 1