update functions. Visualize, Interpretable and with evox
This commit is contained in:
@@ -15,8 +15,8 @@ from tensorneat.common import (
|
||||
topological_sort_python,
|
||||
I_INF,
|
||||
attach_with_inf,
|
||||
SYMPY_FUNCS_MODULE_NP,
|
||||
SYMPY_FUNCS_MODULE_JNP,
|
||||
ACT,
|
||||
AGG
|
||||
)
|
||||
|
||||
|
||||
@@ -92,7 +92,9 @@ class DefaultGenome(BaseGenome):
|
||||
def otherwise():
|
||||
# calculate connections
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
|
||||
hit_attrs = attach_with_inf(
|
||||
conns_attrs, conn_indices
|
||||
) # fetch conn attrs
|
||||
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
state, hit_attrs, values
|
||||
)
|
||||
@@ -102,7 +104,9 @@ class DefaultGenome(BaseGenome):
|
||||
state,
|
||||
nodes_attrs[i],
|
||||
ins,
|
||||
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
|
||||
is_output_node=jnp.isin(
|
||||
nodes[i, 0], self.output_idx
|
||||
), # nodes[0] -> the key of nodes
|
||||
)
|
||||
|
||||
# set new value
|
||||
@@ -139,7 +143,6 @@ class DefaultGenome(BaseGenome):
|
||||
):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||
|
||||
if sympy_input_transform is None and self.input_transform is not None:
|
||||
warnings.warn(
|
||||
@@ -224,7 +227,7 @@ class DefaultGenome(BaseGenome):
|
||||
sp.lambdify(
|
||||
input_symbols + list(args_symbols.keys()),
|
||||
exprs,
|
||||
modules=[backend, module],
|
||||
modules=[backend, AGG.sympy_module(backend), ACT.sympy_module(backend)],
|
||||
)
|
||||
for exprs in output_exprs
|
||||
]
|
||||
@@ -256,7 +259,12 @@ class DefaultGenome(BaseGenome):
|
||||
rotate=0,
|
||||
reverse_node_order=False,
|
||||
size=(300, 300, 300),
|
||||
color=("blue", "blue", "blue"),
|
||||
color=("yellow", "white", "blue"),
|
||||
with_labels=False,
|
||||
edgecolors="k",
|
||||
arrowstyle="->",
|
||||
arrowsize=3,
|
||||
edge_color=(0.3, 0.3, 0.3),
|
||||
save_path="network.svg",
|
||||
save_dpi=800,
|
||||
**kwargs,
|
||||
@@ -264,7 +272,6 @@ class DefaultGenome(BaseGenome):
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
nodes_list = list(network["nodes"])
|
||||
conns_list = list(network["conns"])
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
@@ -316,6 +323,11 @@ class DefaultGenome(BaseGenome):
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
with_labels=with_labels,
|
||||
edgecolors=edgecolors,
|
||||
arrowstyle=arrowstyle,
|
||||
arrowsize=arrowsize,
|
||||
edge_color=edge_color,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
@@ -10,7 +10,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from . import BaseNode
|
||||
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
|
||||
self.__class__.__name__,
|
||||
idx,
|
||||
bias,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
|
||||
return {
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
nd = node_dict
|
||||
bias = sp.symbols(f"n_{node_dict['idx']}_b")
|
||||
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
z = AGG.obtain_sympy(node_dict["agg"])(inputs)
|
||||
|
||||
z = bias + z
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(node_dict["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"]}
|
||||
return z, {bias: node_dict["bias"]}
|
||||
|
||||
@@ -11,7 +11,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from .base import BaseNode
|
||||
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
|
||||
idx,
|
||||
bias,
|
||||
res,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"res": res,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
print(nd["agg"])
|
||||
z = AGG.obtain_sympy(nd["agg"])(inputs)
|
||||
z = bias + res * z
|
||||
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(nd["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"], res: nd["res"]}
|
||||
|
||||
Reference in New Issue
Block a user