add sympy support; which can transfer your network into sympy expression;
add visualize in genome; add related tests.
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -7,3 +7,7 @@ __pycache__/
|
||||
*.log
|
||||
.cache/
|
||||
.tmp/
|
||||
|
||||
# Ignore files named exactly 'tmp' or 'aux'
|
||||
tmp
|
||||
aux
|
||||
@@ -31,3 +31,13 @@ class BaseConnGene(BaseGene):
|
||||
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
|
||||
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
|
||||
)
|
||||
|
||||
def to_dict(self, state, conn):
|
||||
in_idx, out_idx = conn[:2]
|
||||
return {
|
||||
"in": int(in_idx),
|
||||
"out": int(out_idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -76,3 +76,17 @@ class DefaultConnGene(BaseConnGene):
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
)
|
||||
|
||||
def to_dict(self, state, conn):
|
||||
return {
|
||||
"in": int(conn[0]),
|
||||
"out": int(conn[1]),
|
||||
"weight": float(conn[2]),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
weight = conn_dict["weight"]
|
||||
if precision is not None:
|
||||
weight = round(weight, precision)
|
||||
|
||||
return inputs * weight
|
||||
|
||||
@@ -47,3 +47,12 @@ class BaseNodeGene(BaseGene):
|
||||
return "{}(idx={:<{idx_width}})".format(
|
||||
self.__class__.__name__, idx, idx_width=idx_width
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx = node[0]
|
||||
return {
|
||||
"idx": int(idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from utils import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
@@ -45,12 +55,12 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
self.aggregation_default = aggregation_options.index(aggregation_default)
|
||||
self.aggregation_options = aggregation_options
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_indices = np.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
self.activation_default = activation_options.index(activation_default)
|
||||
self.activation_options = activation_options
|
||||
self.activation_indices = jnp.arange(len(activation_options))
|
||||
self.activation_indices = np.arange(len(activation_options))
|
||||
self.activation_replace_rate = activation_replace_rate
|
||||
|
||||
def new_identity_attrs(self, state):
|
||||
@@ -145,5 +155,38 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
act_func.__name__,
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width
|
||||
func_width=func_width,
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, res, agg, act = node
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"res": float(res),
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
|
||||
bias = node_dict["bias"]
|
||||
res = node_dict["res"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
res = round(res, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = bias + z * res
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
|
||||
return z
|
||||
|
||||
@@ -2,7 +2,16 @@ from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from utils import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
@@ -121,3 +130,33 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, agg, act = node
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
|
||||
bias = node_dict["bias"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = bias + z
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
|
||||
return z
|
||||
|
||||
@@ -25,8 +25,3 @@ class KANNode(BaseNodeGene):
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return Agg.sum(inputs)
|
||||
|
||||
def repr(self, state, node, precision=2):
|
||||
idx = node[0]
|
||||
idx = int(idx)
|
||||
return "{}(idx: {})".format(self.__class__.__name__, idx)
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State, StatefulBaseClass
|
||||
from utils import State, StatefulBaseClass, topological_sort_python
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -155,3 +155,112 @@ class BaseGenome(StatefulBaseClass):
|
||||
@classmethod
|
||||
def valid_cnt(cls, arr):
|
||||
return jnp.sum(~jnp.isnan(arr[:, 0]))
|
||||
|
||||
def get_conn_dict(self, state, conns):
|
||||
conns = jax.device_get(conns)
|
||||
conn_dict = {}
|
||||
for conn in conns:
|
||||
if np.isnan(conn[0]):
|
||||
continue
|
||||
cd = self.conn_gene.to_dict(state, conn)
|
||||
in_idx, out_idx = cd["in"], cd["out"]
|
||||
del cd["in"], cd["out"]
|
||||
conn_dict[(in_idx, out_idx)] = cd
|
||||
return conn_dict
|
||||
|
||||
def get_node_dict(self, state, nodes):
|
||||
nodes = jax.device_get(nodes)
|
||||
node_dict = {}
|
||||
for node in nodes:
|
||||
if np.isnan(node[0]):
|
||||
continue
|
||||
nd = self.node_gene.to_dict(state, node)
|
||||
idx = nd["idx"]
|
||||
del nd["idx"]
|
||||
node_dict[idx] = nd
|
||||
return node_dict
|
||||
|
||||
def network_dict(self, state, nodes, conns):
|
||||
return {
|
||||
"nodes": self.get_node_dict(state, nodes),
|
||||
"conns": self.get_conn_dict(state, conns),
|
||||
}
|
||||
|
||||
def get_input_idx(self):
|
||||
return self.input_idx.tolist()
|
||||
|
||||
def get_output_idx(self):
|
||||
return self.output_idx.tolist()
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
raise NotImplementedError
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
network,
|
||||
rotate=0,
|
||||
reverse_node_order=False,
|
||||
size=(300, 300, 300),
|
||||
color=("blue", "blue", "blue"),
|
||||
save_path="network.svg",
|
||||
save_dpi=800,
|
||||
**kwargs,
|
||||
):
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
nodes_list = list(network["nodes"])
|
||||
conns_list = list(network["conns"])
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
|
||||
node2layer = {
|
||||
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
|
||||
}
|
||||
if reverse_node_order:
|
||||
topo_order = topo_order[::-1]
|
||||
|
||||
G = nx.DiGraph()
|
||||
|
||||
if not isinstance(size, tuple):
|
||||
size = (size, size, size)
|
||||
if not isinstance(color, tuple):
|
||||
color = (color, color, color)
|
||||
|
||||
for node in topo_order:
|
||||
if node in input_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
|
||||
elif node in output_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
|
||||
else:
|
||||
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
|
||||
|
||||
for conn in conns_list:
|
||||
G.add_edge(conn[0], conn[1])
|
||||
pos = nx.multipartite_layout(G)
|
||||
|
||||
def rotate_layout(pos, angle):
|
||||
angle_rad = np.deg2rad(angle)
|
||||
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
|
||||
rotated_pos = {}
|
||||
for node, (x, y) in pos.items():
|
||||
rotated_pos[node] = (
|
||||
cos_angle * x - sin_angle * y,
|
||||
sin_angle * x + cos_angle * y,
|
||||
)
|
||||
return rotated_pos
|
||||
|
||||
rotated_pos = rotate_layout(pos, rotate)
|
||||
|
||||
node_sizes = [n["size"] for n in G.nodes.values()]
|
||||
node_colors = [n["color"] for n in G.nodes.values()]
|
||||
|
||||
nx.draw(
|
||||
G,
|
||||
with_labels=True,
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
from utils import (
|
||||
unflatten_conns,
|
||||
topological_sort,
|
||||
topological_sort_python,
|
||||
I_INF,
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
attach_with_inf,
|
||||
FUNCS_MODULE,
|
||||
)
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
|
||||
@@ -188,3 +190,56 @@ class DefaultGenome(BaseGenome):
|
||||
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
|
||||
new_transformed,
|
||||
)
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
symbols = {}
|
||||
for i in network["nodes"]:
|
||||
if i in input_idx:
|
||||
symbols[i] = sp.Symbol(f"i{i}")
|
||||
elif i in output_idx:
|
||||
symbols[i] = sp.Symbol(f"o{i}")
|
||||
else: # hidden
|
||||
symbols[i] = sp.Symbol(f"h{i}")
|
||||
|
||||
nodes_exprs = {}
|
||||
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
nodes_exprs[symbols[i]] = symbols[i]
|
||||
else:
|
||||
in_conns = [c for c in network["conns"] if c[1] == i]
|
||||
node_inputs = []
|
||||
for conn in in_conns:
|
||||
val_represent = symbols[conn[0]]
|
||||
val = self.conn_gene.sympy_func(
|
||||
state,
|
||||
network["conns"][conn],
|
||||
val_represent,
|
||||
precision=precision,
|
||||
)
|
||||
node_inputs.append(val)
|
||||
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
|
||||
state,
|
||||
network["nodes"][i],
|
||||
node_inputs,
|
||||
is_output_node=(i in output_idx),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
input_symbols = [v for k, v in symbols.items() if k in input_idx]
|
||||
reduced_exprs = nodes_exprs.copy()
|
||||
for i in order:
|
||||
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
|
||||
|
||||
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
|
||||
lambdify_output_funcs = [
|
||||
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
|
||||
for exprs in output_exprs
|
||||
]
|
||||
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
|
||||
|
||||
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func
|
||||
|
||||
@@ -84,3 +84,6 @@ class RecurrentGenome(BaseGenome):
|
||||
return vals[self.output_idx]
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
raise ValueError("Sympy function is not supported for Recurrent Network!")
|
||||
|
||||
211
tensorneat/examples/interpret_visualize/genome_sympy.ipynb
Normal file
211
tensorneat/examples/interpret_visualize/genome_sympy.ipynb
Normal file
@@ -0,0 +1,211 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"from algorithm.neat import *\n",
|
||||
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
||||
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
||||
"from utils.graph import topological_sort_python\n",
|
||||
"from utils import Act, Agg"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:46.886073700Z",
|
||||
"start_time": "2024-06-12T11:35:46.042288800Z"
|
||||
}
|
||||
},
|
||||
"id": "9531a569d9ecf774"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"genome = AdvanceInitialize(\n",
|
||||
" num_inputs=3,\n",
|
||||
" num_outputs=1,\n",
|
||||
" hidden_cnt=1,\n",
|
||||
" max_nodes=50,\n",
|
||||
" max_conns=500,\n",
|
||||
" node_gene=NodeGeneWithoutResponse(\n",
|
||||
" # activation_default=Act.tanh,\n",
|
||||
" aggregation_default=Agg.sum,\n",
|
||||
" # activation_options=(Act.tanh,),\n",
|
||||
" aggregation_options=(Agg.sum,),\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"state = genome.setup()\n",
|
||||
"\n",
|
||||
"randkey = jax.random.PRNGKey(42)\n",
|
||||
"nodes, conns = genome.initialize(state, randkey)\n",
|
||||
"\n",
|
||||
"network = genome.network_dict(state, nodes, conns)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.274062400Z",
|
||||
"start_time": "2024-06-12T11:35:46.892042200Z"
|
||||
}
|
||||
},
|
||||
"id": "4013c9f9d5472eb7"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sympy as sp\n",
|
||||
"\n",
|
||||
"symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n",
|
||||
"output_exprs"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.325161800Z",
|
||||
"start_time": "2024-06-12T11:35:52.282008300Z"
|
||||
}
|
||||
},
|
||||
"id": "addea793fc002900"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(sp.latex(output_exprs[0]))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.341639700Z",
|
||||
"start_time": "2024-06-12T11:35:52.323163700Z"
|
||||
}
|
||||
},
|
||||
"id": "967cb87e24373f77"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "88eee4db9eb857cd"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[-0.7940936986556304]"
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"random_inputs = np.random.randn(3)\n",
|
||||
"res = forward_func(random_inputs)\n",
|
||||
"res "
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.342638Z",
|
||||
"start_time": "2024-06-12T11:35:52.330160600Z"
|
||||
}
|
||||
},
|
||||
"id": "c5581201d990ba1c"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"transformed = genome.transform(state, nodes, conns)\n",
|
||||
"res = genome.forward(state, transformed, random_inputs)\n",
|
||||
"res"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:53.273851900Z",
|
||||
"start_time": "2024-06-12T11:35:52.384588600Z"
|
||||
}
|
||||
},
|
||||
"id": "fe3449a5bc688bc3"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:53.274854100Z",
|
||||
"start_time": "2024-06-12T11:35:53.265856700Z"
|
||||
}
|
||||
},
|
||||
"id": "174c7dc3d9499f95"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
30
tensorneat/examples/interpret_visualize/genome_sympy.py
Normal file
30
tensorneat/examples/interpret_visualize/genome_sympy.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.advance import AdvanceInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
|
||||
if __name__ == '__main__':
|
||||
genome = AdvanceInitialize(
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
hidden_cnt=8,
|
||||
max_nodes=50,
|
||||
max_conns=500,
|
||||
)
|
||||
|
||||
state = genome.setup()
|
||||
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
network = genome.network_dict(state, nodes, conns)
|
||||
print(set(network["nodes"]), set(network["conns"]))
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
print(order)
|
||||
|
||||
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
|
||||
print(input_idx, output_idx)
|
||||
|
||||
print(genome.repr(state, nodes, conns))
|
||||
print(network)
|
||||
2455
tensorneat/examples/interpret_visualize/graph.svg
Normal file
2455
tensorneat/examples/interpret_visualize/graph.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 90 KiB |
191
tensorneat/examples/interpret_visualize/network.json
Normal file
191
tensorneat/examples/interpret_visualize/network.json
Normal file
@@ -0,0 +1,191 @@
|
||||
{
|
||||
"nodes": {
|
||||
"0": {
|
||||
"bias": 0.13710324466228485,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"1": {
|
||||
"bias": -1.4202250242233276,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"2": {
|
||||
"bias": -0.4653860926628113,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"3": {
|
||||
"bias": 0.5835710167884827,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"4": {
|
||||
"bias": 2.187405824661255,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"5": {
|
||||
"bias": 0.24963024258613586,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"6": {
|
||||
"bias": -0.966821551322937,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"7": {
|
||||
"bias": 0.4452081620693207,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"8": {
|
||||
"bias": -0.07293166220188141,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"9": {
|
||||
"bias": -0.1625899225473404,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"10": {
|
||||
"bias": -0.8576332330703735,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"11": {
|
||||
"bias": -0.18487468361854553,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"12": {
|
||||
"bias": 1.4335486888885498,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"13": {
|
||||
"bias": -0.8690621256828308,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"14": {
|
||||
"bias": -0.23014676570892334,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"15": {
|
||||
"bias": 0.7880322337150574,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"16": {
|
||||
"bias": -0.22258250415325165,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"17": {
|
||||
"bias": 0.2773352861404419,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"18": {
|
||||
"bias": -0.40279051661491394,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"19": {
|
||||
"bias": 1.092000961303711,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"20": {
|
||||
"bias": -0.4063087999820709,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"21": {
|
||||
"bias": 0.3895529806613922,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"22": {
|
||||
"bias": -0.18007506430149078,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"23": {
|
||||
"bias": -0.8112533092498779,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"24": {
|
||||
"bias": 0.2946726381778717,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"25": {
|
||||
"bias": -1.118497371673584,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"26": {
|
||||
"bias": 1.3674490451812744,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"27": {
|
||||
"bias": -1.6514816284179688,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"28": {
|
||||
"bias": 0.9440701603889465,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"29": {
|
||||
"bias": 1.564852237701416,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"30": {
|
||||
"bias": -0.5568665266036987,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
}
|
||||
},
|
||||
"conns": {
|
||||
|
||||
2455
tensorneat/examples/interpret_visualize/network.svg
Normal file
2455
tensorneat/examples/interpret_visualize/network.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 89 KiB |
103
tensorneat/examples/interpret_visualize/visualize_genome.ipynb
Normal file
103
tensorneat/examples/interpret_visualize/visualize_genome.ipynb
Normal file
File diff suppressed because one or more lines are too long
13
tensorneat/examples/interpret_visualize/visualize_genome.py
Normal file
13
tensorneat/examples/interpret_visualize/visualize_genome.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 创建一个空白的有向图
|
||||
G = nx.DiGraph()
|
||||
|
||||
# 添加边
|
||||
G.add_edge('A', 'B')
|
||||
G.add_edge('A', 'C')
|
||||
G.add_edge('B', 'C')
|
||||
G.add_edge('C', 'D')
|
||||
|
||||
# 绘制有向图
|
||||
@@ -2,19 +2,19 @@ import jax, jax.numpy as jnp
|
||||
import jax.random
|
||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
||||
|
||||
|
||||
def random_policy(state, params, obs):
|
||||
# key = jax.random.key(obs.sum())
|
||||
# actions = jax.random.normal(key, (4,))
|
||||
key = jax.random.key(obs.sum())
|
||||
actions = jax.random.normal(key, (4,))
|
||||
# actions = actions.at[2:].set(-9999)
|
||||
return jnp.array([4, 4, 0, 1])
|
||||
# return jnp.array([4, 4, 0, 1])
|
||||
# return jnp.array([1, 2, 3, 4])
|
||||
# return actions
|
||||
return actions
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
problem = Jumanji_2048(
|
||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=True
|
||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=False
|
||||
)
|
||||
state = problem.setup()
|
||||
jit_evaluate = jax.jit(
|
||||
|
||||
@@ -78,36 +78,37 @@ if __name__ == "__main__":
|
||||
Act.identity,
|
||||
),
|
||||
aggregation_default=Agg.sum,
|
||||
aggregation_options=(Agg.sum,),
|
||||
aggregation_options=(Agg.sum, ),
|
||||
activation_replace_rate=0.02,
|
||||
aggregation_replace_rate=0.02,
|
||||
bias_mutate_rate=0.03,
|
||||
bias_init_std=0.5,
|
||||
bias_mutate_power=0.2,
|
||||
bias_mutate_power=0.02,
|
||||
bias_replace_rate=0.01,
|
||||
),
|
||||
conn_gene=DefaultConnGene(
|
||||
weight_mutate_rate=0.015,
|
||||
weight_replace_rate=0.003,
|
||||
weight_mutate_power=0.5,
|
||||
weight_replace_rate=0.03,
|
||||
weight_mutate_power=0.05,
|
||||
),
|
||||
mutation=DefaultMutation(node_add=0.001, conn_add=0.002),
|
||||
),
|
||||
pop_size=1000,
|
||||
species_size=5,
|
||||
survival_threshold=0.1,
|
||||
survival_threshold=0.01,
|
||||
max_stagnation=7,
|
||||
genome_elitism=3,
|
||||
compatibility_threshold=1.2,
|
||||
),
|
||||
),
|
||||
problem=Jumanji_2048(
|
||||
max_step=10000,
|
||||
repeat_times=10,
|
||||
guarantee_invalid_action=True,
|
||||
max_step=1000,
|
||||
repeat_times=50,
|
||||
# guarantee_invalid_action=True,
|
||||
guarantee_invalid_action=False,
|
||||
action_policy=action_policy,
|
||||
),
|
||||
generation_limit=1000,
|
||||
generation_limit=10000,
|
||||
fitness_target=13000,
|
||||
save_path="2048.npz",
|
||||
)
|
||||
|
||||
@@ -18,8 +18,7 @@ class Jumanji_2048(RLEnv):
|
||||
|
||||
###################################################################
|
||||
|
||||
action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
|
||||
action = (action - 1) / 15
|
||||
# action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
|
||||
|
||||
###################################################################
|
||||
if self.guarantee_invalid_action:
|
||||
|
||||
@@ -2,189 +2,110 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "initial_id",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:55:39.434327400Z",
|
||||
"start_time": "2024-06-06T11:55:39.361327400Z"
|
||||
}
|
||||
},
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)"
|
||||
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
|
||||
},
|
||||
"execution_count": 22,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"a = jnp.array([\n",
|
||||
" [1, 2],\n",
|
||||
" [3, 4]\n",
|
||||
"])\n",
|
||||
"def rot_boards(board):\n",
|
||||
" def rot(a, _):\n",
|
||||
" a = jnp.rot90(a)\n",
|
||||
" return a, a # carry, y\n",
|
||||
" \n",
|
||||
" _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))\n",
|
||||
" return boards\n",
|
||||
"a1 = rot_boards(a)\n",
|
||||
"a2 = rot_boards(a)\n",
|
||||
"\n",
|
||||
"a = jnp.concatenate([a1, a2], axis=0)\n",
|
||||
"a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)"
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a = a.reshape(8, -1)\n",
|
||||
"a"
|
||||
"from algorithm.neat import *\n",
|
||||
"from utils import Act, Agg\n",
|
||||
"genome = DefaultGenome(\n",
|
||||
" num_inputs=27,\n",
|
||||
" num_outputs=8,\n",
|
||||
" max_nodes=100,\n",
|
||||
" max_conns=200,\n",
|
||||
" node_gene=DefaultNodeGene(\n",
|
||||
" activation_options=(Act.tanh,),\n",
|
||||
" activation_default=Act.tanh,\n",
|
||||
" ),\n",
|
||||
" output_transform=Act.tanh,\n",
|
||||
")\n",
|
||||
"state = genome.setup()\n",
|
||||
"genome"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:55:31.121054800Z",
|
||||
"start_time": "2024-06-06T11:55:31.075517200Z"
|
||||
"end_time": "2024-06-09T12:08:22.569123400Z",
|
||||
"start_time": "2024-06-09T12:08:19.331863800Z"
|
||||
}
|
||||
},
|
||||
"id": "639cdecea840351d"
|
||||
"id": "b2b214a5454c4814"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"state = state.register(data=jnp.zeros((1, 27)))\n",
|
||||
"# try to save the genome object\n",
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"with open('genome.pkl', 'wb') as f:\n",
|
||||
" genome.__dict__[\"state\"] = state\n",
|
||||
" pickle.dump(genome, f)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-09T12:09:01.943445900Z",
|
||||
"start_time": "2024-06-09T12:09:01.919416Z"
|
||||
}
|
||||
},
|
||||
"id": "28348dfc458e8473"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"action = [\"up\", \"right\", \"down\", \"left\"]\n",
|
||||
"lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n",
|
||||
"def action_rot90(li):\n",
|
||||
" first = li[0]\n",
|
||||
" return li[1:] + [first]\n",
|
||||
"\n",
|
||||
"a = a\n",
|
||||
"rl_flip_a = jnp.fliplr(a)"
|
||||
"# try to load the genome object\n",
|
||||
"with open('genome.pkl', 'rb') as f:\n",
|
||||
" genome = pickle.load(f)\n",
|
||||
" state = genome.state\n",
|
||||
" del genome.__dict__[\"state\"]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:22:36.417287600Z",
|
||||
"start_time": "2024-06-06T11:22:36.414285500Z"
|
||||
"end_time": "2024-06-09T12:10:28.621539400Z",
|
||||
"start_time": "2024-06-09T12:10:28.612540100Z"
|
||||
}
|
||||
},
|
||||
"id": "92b75cd0e870a28c"
|
||||
"id": "c91be9fe3d2b5d5d"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[1 2]\n",
|
||||
" [3 4]] ['up', 'right', 'down', 'left']\n",
|
||||
"[[2 1]\n",
|
||||
" [4 3]] ['up', 'left', 'down', 'right']\n",
|
||||
"[[2 4]\n",
|
||||
" [1 3]] ['right', 'down', 'left', 'up']\n",
|
||||
"[[1 3]\n",
|
||||
" [2 4]] ['left', 'down', 'right', 'up']\n",
|
||||
"[[4 3]\n",
|
||||
" [2 1]] ['down', 'left', 'up', 'right']\n",
|
||||
"[[3 4]\n",
|
||||
" [1 2]] ['down', 'right', 'up', 'left']\n",
|
||||
"[[3 1]\n",
|
||||
" [4 2]] ['left', 'up', 'right', 'down']\n",
|
||||
"[[4 2]\n",
|
||||
" [3 1]] ['right', 'up', 'left', 'down']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for i in range(4):\n",
|
||||
" print(a, action)\n",
|
||||
" print(rl_flip_a, lr_flip_action)\n",
|
||||
" a = jnp.rot90(a)\n",
|
||||
" rl_flip_a = jnp.rot90(rl_flip_a)\n",
|
||||
" action = action_rot90(action)\n",
|
||||
" lr_flip_action = action_rot90(lr_flip_action)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:22:36.919614600Z",
|
||||
"start_time": "2024-06-06T11:22:36.860704600Z"
|
||||
}
|
||||
},
|
||||
"id": "55e802e0dbcc9c7f"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 15,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)"
|
||||
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jnp.rot90(a, k=2)"
|
||||
"state"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:12:48.186719Z",
|
||||
"start_time": "2024-06-06T11:12:48.151161900Z"
|
||||
"end_time": "2024-06-09T12:10:34.103124Z",
|
||||
"start_time": "2024-06-09T12:10:34.096124300Z"
|
||||
}
|
||||
},
|
||||
"id": "16f8de3cadaa257a"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# flip left-right\n",
|
||||
"jnp.fliplr(a)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-06T11:14:28.668195300Z",
|
||||
"start_time": "2024-06-06T11:14:28.631570500Z"
|
||||
}
|
||||
},
|
||||
"id": "1fffa4e597ab5732"
|
||||
"id": "6852e4e58b81dd9"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -194,7 +115,7 @@
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "ca53c916dcff12ae"
|
||||
"id": "97a50322218a0427"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -1,6 +1,51 @@
|
||||
from .activation import Act, act_func, ACT_ALL
|
||||
from .aggregation import Agg, agg_func, AGG_ALL
|
||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
|
||||
from .activation.act_jnp import Act, ACT_ALL, act_func
|
||||
from .aggregation.agg_sympy import *
|
||||
from .activation.act_sympy import *
|
||||
|
||||
from typing import Union
|
||||
|
||||
# Lookup table from activation/aggregation function names (as used by the
# jax-side Act/Agg helpers) to their sympy counterparts.
name2sympy = {
    "sigmoid": SympySigmoid,
    "tanh": SympyTanh,
    "sin": SympySin,
    "relu": SympyRelu,
    "lelu": SympyLelu,
    "identity": SympyIdentity,
    "clamped": SympyClamped,
    "inv": SympyInv,
    "log": SympyLog,
    "exp": SympyExp,
    "abs": SympyAbs,
    "sum": SympySum,
    "product": SympyProduct,
    "max": SympyMax,
    "min": SympyMin,
    "maxabs": SympyMaxabs,
    "mean": SympyMean,
}


def convert_to_sympy(func: Union[str, callable]):
    """Return the sympy class registered for *func*.

    Args:
        func: either the function name itself, or a callable whose
            ``__name__`` matches a key of ``name2sympy``.

    Raises:
        ValueError: when no sympy equivalent is registered for the name.
    """
    name = func if isinstance(func, str) else func.__name__
    sympy_cls = name2sympy.get(name)
    if sympy_cls is None:
        raise ValueError(
            f"Can not convert to sympy! Function {name} not found in name2sympy"
        )
    return sympy_cls


# Namespace of plain-numpy implementations, keyed by sympy class name; handed
# to lambdify-style evaluation so custom functions have a numeric backend.
FUNCS_MODULE = {
    cls.__name__: cls.numerical_eval
    for cls in name2sympy.values()
    if hasattr(cls, "numerical_eval")
}
|
||||
|
||||
0
tensorneat/utils/activation/__init__.py
Normal file
0
tensorneat/utils/activation/__init__.py
Normal file
@@ -3,6 +3,10 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Act:
|
||||
@staticmethod
|
||||
def name2func(name):
|
||||
return getattr(Act, name)
|
||||
|
||||
@staticmethod
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(5 * z, -10, 10)
|
||||
191
tensorneat/utils/activation/act_sympy.py
Normal file
191
tensorneat/utils/activation/act_sympy.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from typing import Union
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SympyClip(sp.Function):
    """Symbolic clip(val, min_val, max_val); left unevaluated for symbols."""

    @classmethod
    def eval(cls, val, min_val, max_val):
        # Collapse to a Piecewise only once every argument is concrete;
        # returning None keeps the symbolic call unevaluated.
        if not (val.is_Number and min_val.is_Number and max_val.is_Number):
            return None
        return sp.Piecewise(
            (min_val, val < min_val), (max_val, val > max_val), (val, True)
        )

    @staticmethod
    def numerical_eval(val, min_val, max_val):
        # numpy backend used when evaluating the expression numerically
        return np.clip(val, min_val, max_val)

    def _sympystr(self, printer):
        val, lo, hi = self.args
        return f"clip({val}, {lo}, {hi})"

    def _latex(self, printer):
        val, lo, hi = self.args
        return rf"\mathrm{{clip}}\left({sp.latex(val)}, {lo}, {hi}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
    """Symbolic sigmoid: 1 / (1 + exp(-clip(5*z, -10, 10)))."""

    @classmethod
    def eval(cls, z):
        if not z.is_Number:
            return None  # stay unevaluated on symbolic input
        clipped = SympyClip(5 * z, -10, 10)
        return 1 / (1 + sp.exp(-clipped))

    @staticmethod
    def numerical_eval(z):
        clipped = np.clip(5 * z, -10, 10)
        return 1 / (1 + np.exp(-clipped))

    def _sympystr(self, printer):
        return f"sigmoid({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyTanh(sp.Function):
    """Symbolic tanh(0.6 * z) activation."""

    @classmethod
    def eval(cls, z):
        return sp.tanh(z * 0.6)

    @staticmethod
    def numerical_eval(z):
        return np.tanh(z * 0.6)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
    """Symbolic sine activation."""

    @classmethod
    def eval(cls, z):
        return sp.sin(z)

    @staticmethod
    def numerical_eval(z):
        return np.sin(z)
|
||||
|
||||
|
||||
class SympyRelu(sp.Function):
    """Symbolic relu; evaluates only for concrete numbers."""

    @classmethod
    def eval(cls, z):
        if not z.is_Number:
            return None  # symbolic input stays unevaluated
        return sp.Piecewise((z, z > 0), (0, True))

    @staticmethod
    def numerical_eval(z):
        return np.maximum(z, 0)

    def _sympystr(self, printer):
        return f"relu({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyLelu(sp.Function):
    """Symbolic leaky relu with negative-side slope 0.005."""

    @classmethod
    def eval(cls, z):
        if not z.is_Number:
            return None  # symbolic input stays unevaluated
        slope = 0.005
        return sp.Piecewise((z, z > 0), (slope * z, True))

    @staticmethod
    def numerical_eval(z):
        slope = 0.005
        return np.maximum(z, slope * z)

    def _sympystr(self, printer):
        return f"lelu({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyIdentity(sp.Function):
    """Identity activation: returns its argument unchanged."""

    @classmethod
    def eval(cls, z):
        return z

    @staticmethod
    def numerical_eval(z):
        return z
|
||||
|
||||
|
||||
class SympyClamped(sp.Function):
    """Clamp the input to [-1, 1] (symbolically via SympyClip)."""

    @classmethod
    def eval(cls, z):
        return SympyClip(z, -1, 1)

    @staticmethod
    def numerical_eval(z):
        return np.clip(z, -1, 1)
|
||||
|
||||
|
||||
class SympyInv(sp.Function):
    """Symbolic 1/z with |z| bounded away from zero (>= 1e-7), keeping sign."""

    @classmethod
    def eval(cls, z):
        if z.is_Number:
            # bound the magnitude away from zero while preserving the sign
            z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
            return 1 / z
        return None  # symbolic input stays unevaluated

    @staticmethod
    def numerical_eval(z):
        # BUGFIX: the previous np.maximum(z, 1e-7) replaced every negative
        # input with 1e-7, so numerical_eval disagreed with eval() on z < 0.
        # Clamp the magnitude but keep the sign, matching the Piecewise above
        # (z == 0 maps to -1e-7 in both versions).
        z = np.where(z > 0, np.maximum(z, 1e-7), np.minimum(z, -1e-7))
        return 1 / z

    def _sympystr(self, printer):
        return f"1 / {self.args[0]}"

    def _latex(self, printer):
        return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
|
||||
|
||||
|
||||
class SympyLog(sp.Function):
    """Symbolic log with the argument floored at 1e-7."""

    @classmethod
    def eval(cls, z):
        if not z.is_Number:
            return None  # symbolic input stays unevaluated
        return sp.log(sp.Max(z, 1e-7))

    @staticmethod
    def numerical_eval(z):
        return np.log(np.maximum(z, 1e-7))

    def _sympystr(self, printer):
        return f"log({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyExp(sp.Function):
    """Symbolic exp with the argument clipped to [-10, 10]."""

    @classmethod
    def eval(cls, z):
        if not z.is_Number:
            return None  # symbolic input stays unevaluated
        return sp.exp(SympyClip(z, -10, 10))

    @staticmethod
    def numerical_eval(z):
        return np.exp(np.clip(z, -10, 10))

    def _sympystr(self, printer):
        return f"exp({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyAbs(sp.Function):
    """Absolute-value activation."""

    @classmethod
    def eval(cls, z):
        return sp.Abs(z)

    @staticmethod
    def numerical_eval(z):
        return np.abs(z)
|
||||
0
tensorneat/utils/aggregation/__init__.py
Normal file
0
tensorneat/utils/aggregation/__init__.py
Normal file
@@ -3,6 +3,10 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Agg:
|
||||
@staticmethod
|
||||
def name2func(name):
|
||||
return getattr(Agg, name)
|
||||
|
||||
@staticmethod
|
||||
def sum(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
69
tensorneat/utils/aggregation/agg_sympy.py
Normal file
69
tensorneat/utils/aggregation/agg_sympy.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import sympy as sp
|
||||
|
||||
|
||||
class SympySum(sp.Function):
    """Sum aggregation over an iterable of symbolic inputs."""

    @classmethod
    def eval(cls, z):
        return sp.Add(*z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
    """Product aggregation over an iterable of symbolic inputs."""

    @classmethod
    def eval(cls, z):
        return sp.Mul(*z)
|
||||
|
||||
|
||||
class SympyMax(sp.Function):
    """Maximum aggregation over an iterable of symbolic inputs."""

    @classmethod
    def eval(cls, z):
        return sp.Max(*z)
|
||||
|
||||
|
||||
class SympyMin(sp.Function):
    """Minimum aggregation over an iterable of symbolic inputs."""

    @classmethod
    def eval(cls, z):
        return sp.Min(*z)
|
||||
|
||||
|
||||
class SympyMaxabs(sp.Function):
    """Aggregation returning the (signed) element with the largest |value|."""

    @classmethod
    def eval(cls, z):
        # BUGFIX: sp.Max does not accept a ``key`` argument (extra keywords
        # are treated as assumptions), so ``sp.Max(*z, key=sp.Abs)`` never
        # selected by absolute value. Select explicitly when all inputs are
        # concrete; otherwise stay unevaluated.
        args = [sp.sympify(e) for e in z]
        if all(a.is_Number for a in args):
            return max(args, key=abs)
        return None

    @staticmethod
    def numerical_eval(z):
        # numpy backend (also registers the class in FUNCS_MODULE, matching
        # the other aggregations): pick the entry with the largest magnitude.
        z = np.asarray(z)
        return z[np.argmax(np.abs(z))]
|
||||
|
||||
|
||||
class SympyMean(sp.Function):
    """Arithmetic-mean aggregation over an iterable of symbolic inputs."""

    @classmethod
    def eval(cls, z):
        total = sp.Add(*z)
        return total / len(z)
|
||||
|
||||
|
||||
class SympyMedian(sp.Function):
    """Median aggregation; evaluates only when every argument is numeric."""

    @classmethod
    def eval(cls, args):
        if not all(a.is_number for a in args):
            return None  # keep symbolic calls unevaluated
        ordered = sorted(args)
        mid = len(ordered) // 2
        if len(ordered) % 2 == 1:
            return ordered[mid]
        return (ordered[mid - 1] + ordered[mid]) / 2

    @staticmethod
    def numerical_eval(args):
        # plain-python mirror of eval() for the numeric backend
        ordered = sorted(args)
        mid = len(ordered) // 2
        if len(ordered) % 2 == 1:
            return ordered[mid]
        return (ordered[mid - 1] + ordered[mid]) / 2

    def _sympystr(self, printer):
        return f"median({', '.join(map(str, self.args))})"

    def _latex(self, printer):
        return (
            r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
        )
|
||||
@@ -5,6 +5,7 @@ Only used in feed-forward networks.
|
||||
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
from typing import Tuple, Set, List, Union
|
||||
|
||||
from .tools import fetch_first, I_INF
|
||||
|
||||
@@ -41,6 +42,60 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
return res
|
||||
|
||||
|
||||
def topological_sort_python(
    nodes: Union[Set[int], List[int]],
    conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
) -> Tuple[List[int], List[List[int]]]:
    """Topologically sort a graph given as plain-python collections.

    Pure-python counterpart of ``topological_sort`` (Kahn's algorithm).

    Args:
        nodes: set or list of node indices.
        conns: set or list of ``(src, dst)`` connection pairs.

    Returns:
        ``(topo_order, topo_layer)`` — the flat topological order and the
        list of layers. Nodes inside each layer are sorted ascending so the
        result is deterministic.

    Raises:
        ValueError: if the graph contains a cycle.
    """
    # Work on copies; the caller's collections are never mutated.
    remaining = set(nodes)

    # In-degree and adjacency list, built in one pass over the edges.
    # (Previously each layer re-scanned and re-copied the whole connection
    # collection, which was O(V * E); this is O(V + E).)
    in_degree = {node: 0 for node in remaining}
    successors = {node: [] for node in remaining}
    n_edges = 0
    for src, dst in conns:
        in_degree[dst] += 1
        # setdefault: a dangling source not listed in `nodes` keeps its edge
        # uncounted-down forever, which is then reported as a cycle below.
        successors.setdefault(src, []).append(dst)
        n_edges += 1

    topo_order = []
    topo_layer = []

    frontier = [node for node in remaining if in_degree[node] == 0]
    while frontier:
        frontier = sorted(frontier)  # keep the order from small to large
        topo_layer.append(frontier.copy())

        next_frontier = []
        for node in frontier:
            remaining.remove(node)
            topo_order.append(node)
            # Releasing this node satisfies one incoming edge of each successor.
            for dst in successors.get(node, []):
                in_degree[dst] -= 1
                n_edges -= 1
                if in_degree[dst] == 0:
                    next_frontier.append(dst)
        frontier = next_frontier

    # Anything left over means some edge could never be satisfied: a cycle.
    if n_edges or remaining:
        raise ValueError("Graph has at least one cycle, topological sort not possible")

    return topo_order, topo_layer
|
||||
|
||||
|
||||
@jit
|
||||
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user