Remove printing for useless nodes when generating python code or latex formula
This commit is contained in:
@@ -94,6 +94,27 @@ def topological_sort_python(
|
|||||||
|
|
||||||
return topo_order, topo_layer
|
return topo_order, topo_layer
|
||||||
|
|
||||||
|
def find_useful_nodes(
|
||||||
|
nodes: Union[Set[int], List[int]],
|
||||||
|
conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
|
||||||
|
output_idx: Set[int],
|
||||||
|
) -> Set[int]:
|
||||||
|
"""
|
||||||
|
Find all useful nodes (really contribute to outputs)
|
||||||
|
"""
|
||||||
|
useful_nodes = set()
|
||||||
|
useful_nodes = useful_nodes | output_idx
|
||||||
|
while True:
|
||||||
|
aux = set()
|
||||||
|
for in_, out in conns:
|
||||||
|
if out in useful_nodes and in_ not in useful_nodes:
|
||||||
|
aux.add(in_)
|
||||||
|
if len(aux) == 0: # no new nodes
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
useful_nodes = useful_nodes | aux
|
||||||
|
# print(f"All nodes cnt={len(nodes)}, useful nodes cnt={len(useful_nodes)}")
|
||||||
|
return useful_nodes
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||||
|
|||||||
@@ -43,11 +43,14 @@ def replace_variable_names(expression, mode):
|
|||||||
return expression_str
|
return expression_str
|
||||||
|
|
||||||
|
|
||||||
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order):
|
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes):
|
||||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||||
res = "\\begin{align}\n"
|
res = "\\begin{align}\n"
|
||||||
|
|
||||||
for i in topo_order[input_cnt: ]:
|
for i in topo_order[input_cnt: ]:
|
||||||
|
# do not add node that does not contribute to output; useful nodes may be nan and cause error
|
||||||
|
if i not in useful_nodes:
|
||||||
|
continue
|
||||||
symbol = symbols[i]
|
symbol = symbols[i]
|
||||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||||
rounded_expr = round_expr(expr, 2)
|
rounded_expr = round_expr(expr, 2)
|
||||||
@@ -58,7 +61,7 @@ def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_expr
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order):
|
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func, topo_order, useful_nodes):
|
||||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||||
res = ""
|
res = ""
|
||||||
if hidden_cnt > 0:
|
if hidden_cnt > 0:
|
||||||
@@ -66,6 +69,9 @@ def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exp
|
|||||||
res += f"o = np.zeros({output_cnt})\n"
|
res += f"o = np.zeros({output_cnt})\n"
|
||||||
|
|
||||||
for i in topo_order[input_cnt: ]:
|
for i in topo_order[input_cnt: ]:
|
||||||
|
# do not add node that does not contribute to output; useful nodes may be nan and cause error
|
||||||
|
if i not in useful_nodes:
|
||||||
|
continue
|
||||||
symbol = symbols[i]
|
symbol = symbols[i]
|
||||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||||
rounded_expr = round_expr(expr, 6)
|
rounded_expr = round_expr(expr, 6)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
|
|||||||
from tensorneat.common import (
|
from tensorneat.common import (
|
||||||
topological_sort,
|
topological_sort,
|
||||||
topological_sort_python,
|
topological_sort_python,
|
||||||
|
find_useful_nodes,
|
||||||
I_INF,
|
I_INF,
|
||||||
attach_with_inf,
|
attach_with_inf,
|
||||||
ACT,
|
ACT,
|
||||||
@@ -131,6 +132,11 @@ class DefaultGenome(BaseGenome):
|
|||||||
)
|
)
|
||||||
network["topo_order"] = topo_order
|
network["topo_order"] = topo_order
|
||||||
network["topo_layers"] = topo_layers
|
network["topo_layers"] = topo_layers
|
||||||
|
network["useful_nodes"] = find_useful_nodes(
|
||||||
|
set(network["nodes"]),
|
||||||
|
set(network["conns"]),
|
||||||
|
set(self.output_idx)
|
||||||
|
)
|
||||||
return network
|
return network
|
||||||
|
|
||||||
def sympy_func(
|
def sympy_func(
|
||||||
@@ -251,7 +257,8 @@ class DefaultGenome(BaseGenome):
|
|||||||
nodes_exprs,
|
nodes_exprs,
|
||||||
output_exprs,
|
output_exprs,
|
||||||
forward_func,
|
forward_func,
|
||||||
network["topo_order"]
|
network["topo_order"],
|
||||||
|
network["useful_nodes"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def visualize(
|
def visualize(
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ def re_cound_idx(nodes, conns, input_idx, output_idx):
|
|||||||
for i, key in enumerate(nodes[:, 0]):
|
for i, key in enumerate(nodes[:, 0]):
|
||||||
if np.isnan(key):
|
if np.isnan(key):
|
||||||
continue
|
continue
|
||||||
if np.in1d(key, input_idx + output_idx):
|
if np.isin(key, input_idx + output_idx):
|
||||||
continue
|
continue
|
||||||
old2new[int(key)] = next_key
|
old2new[int(key)] = next_key
|
||||||
next_key += 1
|
next_key += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user