update functions. Visualize, Interpretable and with evox

This commit is contained in:
root
2024-07-12 04:35:22 +08:00
parent 5fc63fdaf1
commit 0d6e7477bf
32 changed files with 207 additions and 427 deletions

View File

@@ -15,8 +15,8 @@ from tensorneat.common import (
topological_sort_python,
I_INF,
attach_with_inf,
SYMPY_FUNCS_MODULE_NP,
SYMPY_FUNCS_MODULE_JNP,
ACT,
AGG
)
@@ -92,7 +92,9 @@ class DefaultGenome(BaseGenome):
def otherwise():
# calculate connections
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
hit_attrs = attach_with_inf(
conns_attrs, conn_indices
) # fetch conn attrs
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values
)
@@ -102,7 +104,9 @@ class DefaultGenome(BaseGenome):
state,
nodes_attrs[i],
ins,
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
is_output_node=jnp.isin(
nodes[i, 0], self.output_idx
), # nodes[0] -> the key of nodes
)
# set new value
@@ -139,7 +143,6 @@ class DefaultGenome(BaseGenome):
):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
if sympy_input_transform is None and self.input_transform is not None:
warnings.warn(
@@ -224,7 +227,7 @@ class DefaultGenome(BaseGenome):
sp.lambdify(
input_symbols + list(args_symbols.keys()),
exprs,
modules=[backend, module],
modules=[backend, AGG.sympy_module(backend), ACT.sympy_module(backend)],
)
for exprs in output_exprs
]
@@ -256,7 +259,12 @@ class DefaultGenome(BaseGenome):
rotate=0,
reverse_node_order=False,
size=(300, 300, 300),
color=("blue", "blue", "blue"),
color=("yellow", "white", "blue"),
with_labels=False,
edgecolors="k",
arrowstyle="->",
arrowsize=3,
edge_color=(0.3, 0.3, 0.3),
save_path="network.svg",
save_dpi=800,
**kwargs,
@@ -264,7 +272,6 @@ class DefaultGenome(BaseGenome):
import networkx as nx
from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"])
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
@@ -316,6 +323,11 @@ class DefaultGenome(BaseGenome):
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
with_labels=with_labels,
edgecolors=edgecolors,
arrowstyle=arrowstyle,
arrowsize=arrowsize,
edge_color=edge_color,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)

View File

@@ -10,7 +10,7 @@ from tensorneat.common import (
apply_aggregation,
mutate_int,
mutate_float,
convert_to_sympy,
get_func_name
)
from . import BaseNode
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
self.__class__.__name__,
idx,
bias,
self.aggregation_options[agg].__name__,
act_func.__name__,
get_func_name(self.aggregation_options[agg]),
get_func_name(act_func),
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
return {
"idx": idx,
"bias": bias,
"agg": self.aggregation_options[int(agg)].__name__,
"act": act_func.__name__,
"agg": get_func_name(self.aggregation_options[agg]),
"act": get_func_name(act_func),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
nd = node_dict
bias = sp.symbols(f"n_{node_dict['idx']}_b")
bias = sp.symbols(f"n_{nd['idx']}_b")
z = convert_to_sympy(nd["agg"])(inputs)
z = AGG.obtain_sympy(node_dict["agg"])(inputs)
z = bias + z
if is_output_node:
pass
else:
z = convert_to_sympy(nd["act"])(z)
z = ACT.obtain_sympy(node_dict["act"])(z)
return z, {bias: nd["bias"]}
return z, {bias: node_dict["bias"]}

View File

@@ -11,7 +11,7 @@ from tensorneat.common import (
apply_aggregation,
mutate_int,
mutate_float,
convert_to_sympy,
get_func_name
)
from .base import BaseNode
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
idx,
bias,
res,
self.aggregation_options[agg].__name__,
act_func.__name__,
get_func_name(self.aggregation_options[agg]),
get_func_name(act_func),
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
"idx": idx,
"bias": bias,
"res": res,
"agg": self.aggregation_options[int(agg)].__name__,
"act": act_func.__name__,
"agg": get_func_name(self.aggregation_options[agg]),
"act": get_func_name(act_func),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
bias = sp.symbols(f"n_{nd['idx']}_b")
res = sp.symbols(f"n_{nd['idx']}_r")
z = convert_to_sympy(nd["agg"])(inputs)
print(nd["agg"])
z = AGG.obtain_sympy(nd["agg"])(inputs)
z = bias + res * z
if is_output_node:
pass
else:
z = convert_to_sympy(nd["act"])(z)
z = ACT.obtain_sympy(nd["act"])(z)
return z, {bias: nd["bias"], res: nd["res"]}