add sympy support; which can transfer your network into sympy expression;
add visualize in genome; add related tests.
This commit is contained in:
@@ -31,3 +31,13 @@ class BaseConnGene(BaseGene):
|
||||
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
|
||||
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
|
||||
)
|
||||
|
||||
def to_dict(self, state, conn):
|
||||
in_idx, out_idx = conn[:2]
|
||||
return {
|
||||
"in": int(in_idx),
|
||||
"out": int(out_idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -76,3 +76,17 @@ class DefaultConnGene(BaseConnGene):
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
)
|
||||
|
||||
def to_dict(self, state, conn):
|
||||
return {
|
||||
"in": int(conn[0]),
|
||||
"out": int(conn[1]),
|
||||
"weight": float(conn[2]),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
weight = conn_dict["weight"]
|
||||
if precision is not None:
|
||||
weight = round(weight, precision)
|
||||
|
||||
return inputs * weight
|
||||
|
||||
@@ -47,3 +47,12 @@ class BaseNodeGene(BaseGene):
|
||||
return "{}(idx={:<{idx_width}})".format(
|
||||
self.__class__.__name__, idx, idx_width=idx_width
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx = node[0]
|
||||
return {
|
||||
"idx": int(idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from utils import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
@@ -45,12 +55,12 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
self.aggregation_default = aggregation_options.index(aggregation_default)
|
||||
self.aggregation_options = aggregation_options
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_indices = np.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
self.activation_default = activation_options.index(activation_default)
|
||||
self.activation_options = activation_options
|
||||
self.activation_indices = jnp.arange(len(activation_options))
|
||||
self.activation_indices = np.arange(len(activation_options))
|
||||
self.activation_replace_rate = activation_replace_rate
|
||||
|
||||
def new_identity_attrs(self, state):
|
||||
@@ -145,5 +155,38 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
act_func.__name__,
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width
|
||||
func_width=func_width,
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, res, agg, act = node
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"res": float(res),
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
|
||||
bias = node_dict["bias"]
|
||||
res = node_dict["res"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
res = round(res, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = bias + z * res
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
|
||||
return z
|
||||
|
||||
@@ -2,7 +2,16 @@ from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from utils import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
@@ -121,3 +130,33 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, agg, act = node
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
|
||||
bias = node_dict["bias"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = bias + z
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
|
||||
return z
|
||||
|
||||
@@ -25,8 +25,3 @@ class KANNode(BaseNodeGene):
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return Agg.sum(inputs)
|
||||
|
||||
def repr(self, state, node, precision=2):
|
||||
idx = node[0]
|
||||
idx = int(idx)
|
||||
return "{}(idx: {})".format(self.__class__.__name__, idx)
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State, StatefulBaseClass
|
||||
from utils import State, StatefulBaseClass, topological_sort_python
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -155,3 +155,112 @@ class BaseGenome(StatefulBaseClass):
|
||||
@classmethod
|
||||
def valid_cnt(cls, arr):
|
||||
return jnp.sum(~jnp.isnan(arr[:, 0]))
|
||||
|
||||
def get_conn_dict(self, state, conns):
|
||||
conns = jax.device_get(conns)
|
||||
conn_dict = {}
|
||||
for conn in conns:
|
||||
if np.isnan(conn[0]):
|
||||
continue
|
||||
cd = self.conn_gene.to_dict(state, conn)
|
||||
in_idx, out_idx = cd["in"], cd["out"]
|
||||
del cd["in"], cd["out"]
|
||||
conn_dict[(in_idx, out_idx)] = cd
|
||||
return conn_dict
|
||||
|
||||
def get_node_dict(self, state, nodes):
|
||||
nodes = jax.device_get(nodes)
|
||||
node_dict = {}
|
||||
for node in nodes:
|
||||
if np.isnan(node[0]):
|
||||
continue
|
||||
nd = self.node_gene.to_dict(state, node)
|
||||
idx = nd["idx"]
|
||||
del nd["idx"]
|
||||
node_dict[idx] = nd
|
||||
return node_dict
|
||||
|
||||
def network_dict(self, state, nodes, conns):
|
||||
return {
|
||||
"nodes": self.get_node_dict(state, nodes),
|
||||
"conns": self.get_conn_dict(state, conns),
|
||||
}
|
||||
|
||||
def get_input_idx(self):
|
||||
return self.input_idx.tolist()
|
||||
|
||||
def get_output_idx(self):
|
||||
return self.output_idx.tolist()
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
raise NotImplementedError
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
network,
|
||||
rotate=0,
|
||||
reverse_node_order=False,
|
||||
size=(300, 300, 300),
|
||||
color=("blue", "blue", "blue"),
|
||||
save_path="network.svg",
|
||||
save_dpi=800,
|
||||
**kwargs,
|
||||
):
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
nodes_list = list(network["nodes"])
|
||||
conns_list = list(network["conns"])
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
|
||||
node2layer = {
|
||||
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
|
||||
}
|
||||
if reverse_node_order:
|
||||
topo_order = topo_order[::-1]
|
||||
|
||||
G = nx.DiGraph()
|
||||
|
||||
if not isinstance(size, tuple):
|
||||
size = (size, size, size)
|
||||
if not isinstance(color, tuple):
|
||||
color = (color, color, color)
|
||||
|
||||
for node in topo_order:
|
||||
if node in input_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
|
||||
elif node in output_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
|
||||
else:
|
||||
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
|
||||
|
||||
for conn in conns_list:
|
||||
G.add_edge(conn[0], conn[1])
|
||||
pos = nx.multipartite_layout(G)
|
||||
|
||||
def rotate_layout(pos, angle):
|
||||
angle_rad = np.deg2rad(angle)
|
||||
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
|
||||
rotated_pos = {}
|
||||
for node, (x, y) in pos.items():
|
||||
rotated_pos[node] = (
|
||||
cos_angle * x - sin_angle * y,
|
||||
sin_angle * x + cos_angle * y,
|
||||
)
|
||||
return rotated_pos
|
||||
|
||||
rotated_pos = rotate_layout(pos, rotate)
|
||||
|
||||
node_sizes = [n["size"] for n in G.nodes.values()]
|
||||
node_colors = [n["color"] for n in G.nodes.values()]
|
||||
|
||||
nx.draw(
|
||||
G,
|
||||
with_labels=True,
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
from utils import (
|
||||
unflatten_conns,
|
||||
topological_sort,
|
||||
topological_sort_python,
|
||||
I_INF,
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
attach_with_inf,
|
||||
FUNCS_MODULE,
|
||||
)
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
|
||||
@@ -188,3 +190,56 @@ class DefaultGenome(BaseGenome):
|
||||
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
|
||||
new_transformed,
|
||||
)
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
symbols = {}
|
||||
for i in network["nodes"]:
|
||||
if i in input_idx:
|
||||
symbols[i] = sp.Symbol(f"i{i}")
|
||||
elif i in output_idx:
|
||||
symbols[i] = sp.Symbol(f"o{i}")
|
||||
else: # hidden
|
||||
symbols[i] = sp.Symbol(f"h{i}")
|
||||
|
||||
nodes_exprs = {}
|
||||
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
nodes_exprs[symbols[i]] = symbols[i]
|
||||
else:
|
||||
in_conns = [c for c in network["conns"] if c[1] == i]
|
||||
node_inputs = []
|
||||
for conn in in_conns:
|
||||
val_represent = symbols[conn[0]]
|
||||
val = self.conn_gene.sympy_func(
|
||||
state,
|
||||
network["conns"][conn],
|
||||
val_represent,
|
||||
precision=precision,
|
||||
)
|
||||
node_inputs.append(val)
|
||||
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
|
||||
state,
|
||||
network["nodes"][i],
|
||||
node_inputs,
|
||||
is_output_node=(i in output_idx),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
input_symbols = [v for k, v in symbols.items() if k in input_idx]
|
||||
reduced_exprs = nodes_exprs.copy()
|
||||
for i in order:
|
||||
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
|
||||
|
||||
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
|
||||
lambdify_output_funcs = [
|
||||
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
|
||||
for exprs in output_exprs
|
||||
]
|
||||
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
|
||||
|
||||
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func
|
||||
|
||||
@@ -84,3 +84,6 @@ class RecurrentGenome(BaseGenome):
|
||||
return vals[self.output_idx]
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
raise ValueError("Sympy function is not supported for Recurrent Network!")
|
||||
|
||||
Reference in New Issue
Block a user