update functions. Visualize, Interpretable and with evox

This commit is contained in:
root
2024-07-12 04:35:22 +08:00
parent 5fc63fdaf1
commit 0d6e7477bf
32 changed files with 207 additions and 427 deletions

View File

@@ -10,7 +10,7 @@ from tensorneat.common import (
apply_aggregation,
mutate_int,
mutate_float,
convert_to_sympy,
get_func_name
)
from . import BaseNode
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
self.__class__.__name__,
idx,
bias,
self.aggregation_options[agg].__name__,
act_func.__name__,
get_func_name(self.aggregation_options[agg]),
get_func_name(act_func),
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
return {
"idx": idx,
"bias": bias,
"agg": self.aggregation_options[int(agg)].__name__,
"act": act_func.__name__,
"agg": get_func_name(self.aggregation_options[agg]),
"act": get_func_name(act_func),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
nd = node_dict
bias = sp.symbols(f"n_{node_dict['idx']}_b")
bias = sp.symbols(f"n_{nd['idx']}_b")
z = convert_to_sympy(nd["agg"])(inputs)
z = AGG.obtain_sympy(node_dict["agg"])(inputs)
z = bias + z
if is_output_node:
pass
else:
z = convert_to_sympy(nd["act"])(z)
z = ACT.obtain_sympy(node_dict["act"])(z)
return z, {bias: nd["bias"]}
return z, {bias: node_dict["bias"]}

View File

@@ -11,7 +11,7 @@ from tensorneat.common import (
apply_aggregation,
mutate_int,
mutate_float,
convert_to_sympy,
get_func_name
)
from .base import BaseNode
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
idx,
bias,
res,
self.aggregation_options[agg].__name__,
act_func.__name__,
get_func_name(self.aggregation_options[agg]),
get_func_name(act_func),
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
"idx": idx,
"bias": bias,
"res": res,
"agg": self.aggregation_options[int(agg)].__name__,
"act": act_func.__name__,
"agg": get_func_name(self.aggregation_options[agg]),
"act": get_func_name(act_func),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
bias = sp.symbols(f"n_{nd['idx']}_b")
res = sp.symbols(f"n_{nd['idx']}_r")
z = convert_to_sympy(nd["agg"])(inputs)
print(nd["agg"])
z = AGG.obtain_sympy(nd["agg"])(inputs)
z = bias + res * z
if is_output_node:
pass
else:
z = convert_to_sympy(nd["act"])(z)
z = ACT.obtain_sympy(nd["act"])(z)
return z, {bias: nd["bias"], res: nd["res"]}