Remove printing for useless nodes when generating python code or latex formula
This commit is contained in:
@@ -94,7 +94,28 @@ def topological_sort_python(
|
||||
|
||||
return topo_order, topo_layer
|
||||
|
||||
|
||||
def find_useful_nodes(
|
||||
nodes: Union[Set[int], List[int]],
|
||||
conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
|
||||
output_idx: Set[int],
|
||||
) -> Set[int]:
|
||||
"""
|
||||
Find all useful nodes (really contribute to outputs)
|
||||
"""
|
||||
useful_nodes = set()
|
||||
useful_nodes = useful_nodes | output_idx
|
||||
while True:
|
||||
aux = set()
|
||||
for in_, out in conns:
|
||||
if out in useful_nodes and in_ not in useful_nodes:
|
||||
aux.add(in_)
|
||||
if len(aux) == 0: # no new nodes
|
||||
break
|
||||
else:
|
||||
useful_nodes = useful_nodes | aux
|
||||
# print(f"All nodes cnt={len(nodes)}, useful nodes cnt={len(useful_nodes)}")
|
||||
return useful_nodes
|
||||
|
||||
@jit
|
||||
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||
"""
|
||||
|
||||
@@ -43,11 +43,14 @@ def replace_variable_names(expression, mode):
|
||||
return expression_str
|
||||
|
||||
|
||||
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order):
|
||||
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes):
|
||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||
res = "\\begin{align}\n"
|
||||
|
||||
for i in topo_order[input_cnt: ]:
|
||||
# do not add node that does not contribute to output; useful nodes may be nan and cause error
|
||||
if i not in useful_nodes:
|
||||
continue
|
||||
symbol = symbols[i]
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 2)
|
||||
@@ -58,7 +61,7 @@ def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_expr
|
||||
return res
|
||||
|
||||
|
||||
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order):
|
||||
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes):
|
||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||
res = ""
|
||||
if hidden_cnt > 0:
|
||||
@@ -66,6 +69,9 @@ def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exp
|
||||
res += f"o = np.zeros({output_cnt})\n"
|
||||
|
||||
for i in topo_order[input_cnt: ]:
|
||||
# do not add node that does not contribute to output; useful nodes may be nan and cause error
|
||||
if i not in useful_nodes:
|
||||
continue
|
||||
symbol = symbols[i]
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 6)
|
||||
|
||||
@@ -13,6 +13,7 @@ from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
|
||||
from tensorneat.common import (
|
||||
topological_sort,
|
||||
topological_sort_python,
|
||||
find_useful_nodes,
|
||||
I_INF,
|
||||
attach_with_inf,
|
||||
ACT,
|
||||
@@ -131,6 +132,11 @@ class DefaultGenome(BaseGenome):
|
||||
)
|
||||
network["topo_order"] = topo_order
|
||||
network["topo_layers"] = topo_layers
|
||||
network["useful_nodes"] = find_useful_nodes(
|
||||
set(network["nodes"]),
|
||||
set(network["conns"]),
|
||||
set(self.output_idx)
|
||||
)
|
||||
return network
|
||||
|
||||
def sympy_func(
|
||||
@@ -251,7 +257,8 @@ class DefaultGenome(BaseGenome):
|
||||
nodes_exprs,
|
||||
output_exprs,
|
||||
forward_func,
|
||||
network["topo_order"]
|
||||
network["topo_order"],
|
||||
network["useful_nodes"]
|
||||
)
|
||||
|
||||
def visualize(
|
||||
|
||||
@@ -98,7 +98,7 @@ def re_cound_idx(nodes, conns, input_idx, output_idx):
|
||||
for i, key in enumerate(nodes[:, 0]):
|
||||
if np.isnan(key):
|
||||
continue
|
||||
if np.in1d(key, input_idx + output_idx):
|
||||
if np.isin(key, input_idx + output_idx):
|
||||
continue
|
||||
old2new[int(key)] = next_key
|
||||
next_key += 1
|
||||
|
||||
Reference in New Issue
Block a user