update functions. Visualize, Interpretable and with evox
This commit is contained in:
@@ -10,7 +10,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from . import BaseNode
|
||||
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
|
||||
self.__class__.__name__,
|
||||
idx,
|
||||
bias,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
|
||||
return {
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
nd = node_dict
|
||||
bias = sp.symbols(f"n_{node_dict['idx']}_b")
|
||||
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
z = AGG.obtain_sympy(node_dict["agg"])(inputs)
|
||||
|
||||
z = bias + z
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(node_dict["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"]}
|
||||
return z, {bias: node_dict["bias"]}
|
||||
|
||||
@@ -11,7 +11,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from .base import BaseNode
|
||||
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
|
||||
idx,
|
||||
bias,
|
||||
res,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"res": res,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
print(nd["agg"])
|
||||
z = AGG.obtain_sympy(nd["agg"])(inputs)
|
||||
z = bias + res * z
|
||||
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(nd["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"], res: nd["res"]}
|
||||
|
||||
Reference in New Issue
Block a user