add sympy support; which can transfer your network into sympy expression;

add visualize in genome;
add related tests.
This commit is contained in:
wls2002
2024-06-12 21:36:35 +08:00
parent dfc8f9198e
commit b3e442c688
29 changed files with 6196 additions and 168 deletions

4
.gitignore vendored
View File

@@ -7,3 +7,7 @@ __pycache__/
*.log
.cache/
.tmp/
# Ignore files named exactly 'tmp' or 'aux'
tmp
aux

View File

@@ -31,3 +31,13 @@ class BaseConnGene(BaseGene):
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
)
def to_dict(self, state, conn):
in_idx, out_idx = conn[:2]
return {
"in": int(in_idx),
"out": int(out_idx),
}
def sympy_func(self, state, conn_dict, inputs, precision=None):
raise NotImplementedError

View File

@@ -76,3 +76,17 @@ class DefaultConnGene(BaseConnGene):
idx_width=idx_width,
float_width=precision + 3,
)
def to_dict(self, state, conn):
return {
"in": int(conn[0]),
"out": int(conn[1]),
"weight": float(conn[2]),
}
def sympy_func(self, state, conn_dict, inputs, precision=None):
weight = conn_dict["weight"]
if precision is not None:
weight = round(weight, precision)
return inputs * weight

View File

@@ -47,3 +47,12 @@ class BaseNodeGene(BaseGene):
return "{}(idx={:<{idx_width}})".format(
self.__class__.__name__, idx, idx_width=idx_width
)
def to_dict(self, state, node):
idx = node[0]
return {
"idx": int(idx),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
raise NotImplementedError

View File

@@ -1,8 +1,18 @@
from typing import Tuple
import numpy as np
import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from utils import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene
@@ -45,12 +55,12 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_indices = np.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options))
self.activation_indices = np.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate
def new_identity_attrs(self, state):
@@ -145,5 +155,38 @@ class DefaultNodeGene(BaseNodeGene):
act_func.__name__,
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width
func_width=func_width,
)
def to_dict(self, state, node):
idx, bias, res, agg, act = node
return {
"idx": int(idx),
"bias": float(bias),
"res": float(res),
"agg": self.aggregation_options[int(agg)].__name__,
"act": self.activation_options[int(act)].__name__,
}
def sympy_func(
self, state, node_dict, inputs, is_output_node=False, precision=None
):
bias = node_dict["bias"]
res = node_dict["res"]
agg = node_dict["agg"]
act = node_dict["act"]
if precision is not None:
bias = round(bias, precision)
res = round(res, precision)
z = convert_to_sympy(agg)(inputs)
z = bias + z * res
if is_output_node:
return z
else:
z = convert_to_sympy(act)(z)
return z

View File

@@ -2,7 +2,16 @@ from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from utils import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene
@@ -121,3 +130,33 @@ class NodeGeneWithoutResponse(BaseNodeGene):
float_width=precision + 3,
func_width=func_width,
)
def to_dict(self, state, node):
idx, bias, agg, act = node
return {
"idx": int(idx),
"bias": float(bias),
"agg": self.aggregation_options[int(agg)].__name__,
"act": self.activation_options[int(act)].__name__,
}
def sympy_func(
self, state, node_dict, inputs, is_output_node=False, precision=None
):
bias = node_dict["bias"]
agg = node_dict["agg"]
act = node_dict["act"]
if precision is not None:
bias = round(bias, precision)
z = convert_to_sympy(agg)(inputs)
z = bias + z
if is_output_node:
return z
else:
z = convert_to_sympy(act)(z)
return z

View File

@@ -25,8 +25,3 @@ class KANNode(BaseNodeGene):
def forward(self, state, attrs, inputs, is_output_node=False):
return Agg.sum(inputs)
def repr(self, state, node, precision=2):
idx = node[0]
idx = int(idx)
return "{}(idx: {})".format(self.__class__.__name__, idx)

View File

@@ -2,7 +2,7 @@ import numpy as np
import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover
from utils import State, StatefulBaseClass
from utils import State, StatefulBaseClass, topological_sort_python
class BaseGenome(StatefulBaseClass):
@@ -155,3 +155,112 @@ class BaseGenome(StatefulBaseClass):
@classmethod
def valid_cnt(cls, arr):
return jnp.sum(~jnp.isnan(arr[:, 0]))
def get_conn_dict(self, state, conns):
conns = jax.device_get(conns)
conn_dict = {}
for conn in conns:
if np.isnan(conn[0]):
continue
cd = self.conn_gene.to_dict(state, conn)
in_idx, out_idx = cd["in"], cd["out"]
del cd["in"], cd["out"]
conn_dict[(in_idx, out_idx)] = cd
return conn_dict
def get_node_dict(self, state, nodes):
nodes = jax.device_get(nodes)
node_dict = {}
for node in nodes:
if np.isnan(node[0]):
continue
nd = self.node_gene.to_dict(state, node)
idx = nd["idx"]
del nd["idx"]
node_dict[idx] = nd
return node_dict
def network_dict(self, state, nodes, conns):
return {
"nodes": self.get_node_dict(state, nodes),
"conns": self.get_conn_dict(state, conns),
}
def get_input_idx(self):
return self.input_idx.tolist()
def get_output_idx(self):
return self.output_idx.tolist()
def sympy_func(self, state, network, precision=3):
raise NotImplementedError
def visualize(
self,
network,
rotate=0,
reverse_node_order=False,
size=(300, 300, 300),
color=("blue", "blue", "blue"),
save_path="network.svg",
save_dpi=800,
**kwargs,
):
import networkx as nx
from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"])
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
node2layer = {
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
}
if reverse_node_order:
topo_order = topo_order[::-1]
G = nx.DiGraph()
if not isinstance(size, tuple):
size = (size, size, size)
if not isinstance(color, tuple):
color = (color, color, color)
for node in topo_order:
if node in input_idx:
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
elif node in output_idx:
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
else:
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
for conn in conns_list:
G.add_edge(conn[0], conn[1])
pos = nx.multipartite_layout(G)
def rotate_layout(pos, angle):
angle_rad = np.deg2rad(angle)
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
rotated_pos = {}
for node, (x, y) in pos.items():
rotated_pos[node] = (
cos_angle * x - sin_angle * y,
sin_angle * x + cos_angle * y,
)
return rotated_pos
rotated_pos = rotate_layout(pos, rotate)
node_sizes = [n["size"] for n in G.nodes.values()]
node_colors = [n["color"] for n in G.nodes.values()]
nx.draw(
G,
with_labels=True,
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)

View File

@@ -1,17 +1,19 @@
from typing import Callable
import jax, jax.numpy as jnp
import sympy as sp
from utils import (
unflatten_conns,
topological_sort,
topological_sort_python,
I_INF,
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
attach_with_inf,
FUNCS_MODULE,
)
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
@@ -188,3 +190,56 @@ class DefaultGenome(BaseGenome):
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
new_transformed,
)
def sympy_func(self, state, network, precision=3):
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
symbols = {}
for i in network["nodes"]:
if i in input_idx:
symbols[i] = sp.Symbol(f"i{i}")
elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i}")
else: # hidden
symbols[i] = sp.Symbol(f"h{i}")
nodes_exprs = {}
for i in order:
if i in input_idx:
nodes_exprs[symbols[i]] = symbols[i]
else:
in_conns = [c for c in network["conns"] if c[1] == i]
node_inputs = []
for conn in in_conns:
val_represent = symbols[conn[0]]
val = self.conn_gene.sympy_func(
state,
network["conns"][conn],
val_represent,
precision=precision,
)
node_inputs.append(val)
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
state,
network["nodes"][i],
node_inputs,
is_output_node=(i in output_idx),
precision=precision,
)
input_symbols = [v for k, v in symbols.items() if k in input_idx]
reduced_exprs = nodes_exprs.copy()
for i in order:
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
lambdify_output_funcs = [
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
for exprs in output_exprs
]
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func

View File

@@ -84,3 +84,6 @@ class RecurrentGenome(BaseGenome):
return vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])
def sympy_func(self, state, network, precision=3):
raise ValueError("Sympy function is not supported for Recurrent Network!")

View File

@@ -0,0 +1,211 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"from algorithm.neat import *\n",
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
"from utils.graph import topological_sort_python\n",
"from utils import Act, Agg"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:46.886073700Z",
"start_time": "2024-06-12T11:35:46.042288800Z"
}
},
"id": "9531a569d9ecf774"
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"genome = AdvanceInitialize(\n",
" num_inputs=3,\n",
" num_outputs=1,\n",
" hidden_cnt=1,\n",
" max_nodes=50,\n",
" max_conns=500,\n",
" node_gene=NodeGeneWithoutResponse(\n",
" # activation_default=Act.tanh,\n",
" aggregation_default=Agg.sum,\n",
" # activation_options=(Act.tanh,),\n",
" aggregation_options=(Agg.sum,),\n",
" )\n",
")\n",
"\n",
"state = genome.setup()\n",
"\n",
"randkey = jax.random.PRNGKey(42)\n",
"nodes, conns = genome.initialize(state, randkey)\n",
"\n",
"network = genome.network_dict(state, nodes, conns)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.274062400Z",
"start_time": "2024-06-12T11:35:46.892042200Z"
}
},
"id": "4013c9f9d5472eb7"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sympy as sp\n",
"\n",
"symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n",
"output_exprs"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.325161800Z",
"start_time": "2024-06-12T11:35:52.282008300Z"
}
},
"id": "addea793fc002900"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
]
}
],
"source": [
"print(sp.latex(output_exprs[0]))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.341639700Z",
"start_time": "2024-06-12T11:35:52.323163700Z"
}
},
"id": "967cb87e24373f77"
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
},
"id": "88eee4db9eb857cd"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "[-0.7940936986556304]"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"random_inputs = np.random.randn(3)\n",
"res = forward_func(random_inputs)\n",
"res "
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.342638Z",
"start_time": "2024-06-12T11:35:52.330160600Z"
}
},
"id": "c5581201d990ba1c"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transformed = genome.transform(state, nodes, conns)\n",
"res = genome.forward(state, transformed, random_inputs)\n",
"res"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:53.273851900Z",
"start_time": "2024-06-12T11:35:52.384588600Z"
}
},
"id": "fe3449a5bc688bc3"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:53.274854100Z",
"start_time": "2024-06-12T11:35:53.265856700Z"
}
},
"id": "174c7dc3d9499f95"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,30 @@
import jax, jax.numpy as jnp
from algorithm.neat import *
from algorithm.neat.genome.advance import AdvanceInitialize
from utils.graph import topological_sort_python
if __name__ == '__main__':
genome = AdvanceInitialize(
num_inputs=17,
num_outputs=6,
hidden_cnt=8,
max_nodes=50,
max_conns=500,
)
state = genome.setup()
randkey = jax.random.PRNGKey(42)
nodes, conns = genome.initialize(state, randkey)
network = genome.network_dict(state, nodes, conns)
print(set(network["nodes"]), set(network["conns"]))
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
print(order)
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
print(input_idx, output_idx)
print(genome.repr(state, nodes, conns))
print(network)

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 90 KiB

View File

@@ -0,0 +1,191 @@
{
"nodes": {
"0": {
"bias": 0.13710324466228485,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"1": {
"bias": -1.4202250242233276,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"2": {
"bias": -0.4653860926628113,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"3": {
"bias": 0.5835710167884827,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"4": {
"bias": 2.187405824661255,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"5": {
"bias": 0.24963024258613586,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"6": {
"bias": -0.966821551322937,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"7": {
"bias": 0.4452081620693207,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"8": {
"bias": -0.07293166220188141,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"9": {
"bias": -0.1625899225473404,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"10": {
"bias": -0.8576332330703735,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"11": {
"bias": -0.18487468361854553,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"12": {
"bias": 1.4335486888885498,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"13": {
"bias": -0.8690621256828308,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"14": {
"bias": -0.23014676570892334,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"15": {
"bias": 0.7880322337150574,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"16": {
"bias": -0.22258250415325165,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"17": {
"bias": 0.2773352861404419,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"18": {
"bias": -0.40279051661491394,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"19": {
"bias": 1.092000961303711,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"20": {
"bias": -0.4063087999820709,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"21": {
"bias": 0.3895529806613922,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"22": {
"bias": -0.18007506430149078,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"23": {
"bias": -0.8112533092498779,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"24": {
"bias": 0.2946726381778717,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"25": {
"bias": -1.118497371673584,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"26": {
"bias": 1.3674490451812744,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"27": {
"bias": -1.6514816284179688,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"28": {
"bias": 0.9440701603889465,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"29": {
"bias": 1.564852237701416,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"30": {
"bias": -0.5568665266036987,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
}
},
"conns": {

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 89 KiB

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
import networkx as nx
import matplotlib.pyplot as plt
# 创建一个空白的有向图
G = nx.DiGraph()
# 添加边
G.add_edge('A', 'B')
G.add_edge('A', 'C')
G.add_edge('B', 'C')
G.add_edge('C', 'D')
# 绘制有向图

View File

@@ -2,19 +2,19 @@ import jax, jax.numpy as jnp
import jax.random
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
def random_policy(state, params, obs):
# key = jax.random.key(obs.sum())
# actions = jax.random.normal(key, (4,))
key = jax.random.key(obs.sum())
actions = jax.random.normal(key, (4,))
# actions = actions.at[2:].set(-9999)
return jnp.array([4, 4, 0, 1])
# return jnp.array([4, 4, 0, 1])
# return jnp.array([1, 2, 3, 4])
# return actions
return actions
if __name__ == "__main__":
problem = Jumanji_2048(
max_step=10000, repeat_times=1000, guarantee_invalid_action=True
max_step=10000, repeat_times=1000, guarantee_invalid_action=False
)
state = problem.setup()
jit_evaluate = jax.jit(

View File

@@ -83,31 +83,32 @@ if __name__ == "__main__":
aggregation_replace_rate=0.02,
bias_mutate_rate=0.03,
bias_init_std=0.5,
bias_mutate_power=0.2,
bias_mutate_power=0.02,
bias_replace_rate=0.01,
),
conn_gene=DefaultConnGene(
weight_mutate_rate=0.015,
weight_replace_rate=0.003,
weight_mutate_power=0.5,
weight_replace_rate=0.03,
weight_mutate_power=0.05,
),
mutation=DefaultMutation(node_add=0.001, conn_add=0.002),
),
pop_size=1000,
species_size=5,
survival_threshold=0.1,
survival_threshold=0.01,
max_stagnation=7,
genome_elitism=3,
compatibility_threshold=1.2,
),
),
problem=Jumanji_2048(
max_step=10000,
repeat_times=10,
guarantee_invalid_action=True,
max_step=1000,
repeat_times=50,
# guarantee_invalid_action=True,
guarantee_invalid_action=False,
action_policy=action_policy,
),
generation_limit=1000,
generation_limit=10000,
fitness_target=13000,
save_path="2048.npz",
)

View File

@@ -18,8 +18,7 @@ class Jumanji_2048(RLEnv):
###################################################################
action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
action = (action - 1) / 15
# action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
###################################################################
if self.guarantee_invalid_action:

View File

@@ -2,189 +2,110 @@
"cells": [
{
"cell_type": "code",
"execution_count": 22,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-06-06T11:55:39.434327400Z",
"start_time": "2024-06-06T11:55:39.361327400Z"
}
},
"execution_count": 1,
"outputs": [
{
"data": {
"text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)"
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
},
"execution_count": 22,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"a = jnp.array([\n",
" [1, 2],\n",
" [3, 4]\n",
"])\n",
"def rot_boards(board):\n",
" def rot(a, _):\n",
" a = jnp.rot90(a)\n",
" return a, a # carry, y\n",
" \n",
" _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))\n",
" return boards\n",
"a1 = rot_boards(a)\n",
"a2 = rot_boards(a)\n",
"\n",
"a = jnp.concatenate([a1, a2], axis=0)\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)"
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = a.reshape(8, -1)\n",
"a"
"from algorithm.neat import *\n",
"from utils import Act, Agg\n",
"genome = DefaultGenome(\n",
" num_inputs=27,\n",
" num_outputs=8,\n",
" max_nodes=100,\n",
" max_conns=200,\n",
" node_gene=DefaultNodeGene(\n",
" activation_options=(Act.tanh,),\n",
" activation_default=Act.tanh,\n",
" ),\n",
" output_transform=Act.tanh,\n",
")\n",
"state = genome.setup()\n",
"genome"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:55:31.121054800Z",
"start_time": "2024-06-06T11:55:31.075517200Z"
"end_time": "2024-06-09T12:08:22.569123400Z",
"start_time": "2024-06-09T12:08:19.331863800Z"
}
},
"id": "639cdecea840351d"
"id": "b2b214a5454c4814"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"state = state.register(data=jnp.zeros((1, 27)))\n",
"# try to save the genome object\n",
"import pickle\n",
"\n",
"with open('genome.pkl', 'wb') as f:\n",
" genome.__dict__[\"state\"] = state\n",
" pickle.dump(genome, f)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:09:01.943445900Z",
"start_time": "2024-06-09T12:09:01.919416Z"
}
},
"id": "28348dfc458e8473"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"action = [\"up\", \"right\", \"down\", \"left\"]\n",
"lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n",
"def action_rot90(li):\n",
" first = li[0]\n",
" return li[1:] + [first]\n",
"\n",
"a = a\n",
"rl_flip_a = jnp.fliplr(a)"
"# try to load the genome object\n",
"with open('genome.pkl', 'rb') as f:\n",
" genome = pickle.load(f)\n",
" state = genome.state\n",
" del genome.__dict__[\"state\"]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.417287600Z",
"start_time": "2024-06-06T11:22:36.414285500Z"
"end_time": "2024-06-09T12:10:28.621539400Z",
"start_time": "2024-06-09T12:10:28.612540100Z"
}
},
"id": "92b75cd0e870a28c"
"id": "c91be9fe3d2b5d5d"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1 2]\n",
" [3 4]] ['up', 'right', 'down', 'left']\n",
"[[2 1]\n",
" [4 3]] ['up', 'left', 'down', 'right']\n",
"[[2 4]\n",
" [1 3]] ['right', 'down', 'left', 'up']\n",
"[[1 3]\n",
" [2 4]] ['left', 'down', 'right', 'up']\n",
"[[4 3]\n",
" [2 1]] ['down', 'left', 'up', 'right']\n",
"[[3 4]\n",
" [1 2]] ['down', 'right', 'up', 'left']\n",
"[[3 1]\n",
" [4 2]] ['left', 'up', 'right', 'down']\n",
"[[4 2]\n",
" [3 1]] ['right', 'up', 'left', 'down']\n"
]
}
],
"source": [
"for i in range(4):\n",
" print(a, action)\n",
" print(rl_flip_a, lr_flip_action)\n",
" a = jnp.rot90(a)\n",
" rl_flip_a = jnp.rot90(rl_flip_a)\n",
" action = action_rot90(action)\n",
" lr_flip_action = action_rot90(lr_flip_action)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.919614600Z",
"start_time": "2024-06-06T11:22:36.860704600Z"
}
},
"id": "55e802e0dbcc9c7f"
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 15,
"outputs": [
{
"data": {
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)"
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
},
"execution_count": 6,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.rot90(a, k=2)"
"state"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:12:48.186719Z",
"start_time": "2024-06-06T11:12:48.151161900Z"
"end_time": "2024-06-09T12:10:34.103124Z",
"start_time": "2024-06-09T12:10:34.096124300Z"
}
},
"id": "16f8de3cadaa257a"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# flip left-right\n",
"jnp.fliplr(a)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:14:28.668195300Z",
"start_time": "2024-06-06T11:14:28.631570500Z"
}
},
"id": "1fffa4e597ab5732"
"id": "6852e4e58b81dd9"
},
{
"cell_type": "code",
@@ -194,7 +115,7 @@
"metadata": {
"collapsed": false
},
"id": "ca53c916dcff12ae"
"id": "97a50322218a0427"
}
],
"metadata": {

View File

@@ -1,6 +1,51 @@
from .activation import Act, act_func, ACT_ALL
from .aggregation import Agg, agg_func, AGG_ALL
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
from .tools import *
from .graph import *
from .state import State
from .stateful_class import StatefulBaseClass
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
from .activation.act_jnp import Act, ACT_ALL, act_func
from .aggregation.agg_sympy import *
from .activation.act_sympy import *
from typing import Union
name2sympy = {
"sigmoid": SympySigmoid,
"tanh": SympyTanh,
"sin": SympySin,
"relu": SympyRelu,
"lelu": SympyLelu,
"identity": SympyIdentity,
"clamped": SympyClamped,
"inv": SympyInv,
"log": SympyLog,
"exp": SympyExp,
"abs": SympyAbs,
"sum": SympySum,
"product": SympyProduct,
"max": SympyMax,
"min": SympyMin,
"maxabs": SympyMaxabs,
"mean": SympyMean,
}
def convert_to_sympy(func: Union[str, callable]):
if isinstance(func, str):
name = func
else:
name = func.__name__
if name in name2sympy:
return name2sympy[name]
else:
raise ValueError(
f"Can not convert to sympy! Function {name} not found in name2sympy"
)
FUNCS_MODULE = {}
for cls in name2sympy.values():
if hasattr(cls, "numerical_eval"):
FUNCS_MODULE[cls.__name__] = cls.numerical_eval

View File

View File

@@ -3,6 +3,10 @@ import jax.numpy as jnp
class Act:
@staticmethod
def name2func(name):
return getattr(Act, name)
@staticmethod
def sigmoid(z):
z = jnp.clip(5 * z, -10, 10)

View File

@@ -0,0 +1,191 @@
from typing import Union
import sympy as sp
import numpy as np
class SympyClip(sp.Function):
@classmethod
def eval(cls, val, min_val, max_val):
if val.is_Number and min_val.is_Number and max_val.is_Number:
return sp.Piecewise(
(min_val, val < min_val), (max_val, val > max_val), (val, True)
)
return None
@staticmethod
def numerical_eval(val, min_val, max_val):
return np.clip(val, min_val, max_val)
def _sympystr(self, printer):
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
def _latex(self, printer):
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
class SympySigmoid(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(5 * z, -10, 10)
return 1 / (1 + sp.exp(-z))
return None
@staticmethod
def numerical_eval(z):
z = np.clip(5 * z, -10, 10)
return 1 / (1 + np.exp(-z))
def _sympystr(self, printer):
return f"sigmoid({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
class SympyTanh(sp.Function):
@classmethod
def eval(cls, z):
return sp.tanh(0.6 * z)
@staticmethod
def numerical_eval(z):
return np.tanh(0.6 * z)
class SympySin(sp.Function):
@classmethod
def eval(cls, z):
return sp.sin(z)
@staticmethod
def numerical_eval(z):
return np.sin(z)
class SympyRelu(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
return sp.Piecewise((z, z > 0), (0, True))
return None
@staticmethod
def numerical_eval(z):
return np.maximum(z, 0)
def _sympystr(self, printer):
return f"relu({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
class SympyLelu(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
leaky = 0.005
return sp.Piecewise((z, z > 0), (leaky * z, True))
return None
@staticmethod
def numerical_eval(z):
leaky = 0.005
return np.maximum(z, leaky * z)
def _sympystr(self, printer):
return f"lelu({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
class SympyIdentity(sp.Function):
@classmethod
def eval(cls, z):
return z
@staticmethod
def numerical_eval(z):
return z
class SympyClamped(sp.Function):
@classmethod
def eval(cls, z):
return SympyClip(z, -1, 1)
@staticmethod
def numerical_eval(z):
return np.clip(z, -1, 1)
class SympyInv(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
return 1 / z
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return 1 / z
def _sympystr(self, printer):
return f"1 / {self.args[0]}"
def _latex(self, printer):
return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
class SympyLog(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = sp.Max(z, 1e-7)
return sp.log(z)
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return np.log(z)
def _sympystr(self, printer):
return f"log({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
class SympyExp(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(z, -10, 10)
return sp.exp(z)
return None
@staticmethod
def numerical_eval(z):
z = np.clip(z, -10, 10)
return np.exp(z)
def _sympystr(self, printer):
return f"exp({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
class SympyAbs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Abs(z)
@staticmethod
def numerical_eval(z):
return np.abs(z)

View File

View File

@@ -3,6 +3,10 @@ import jax.numpy as jnp
class Agg:
@staticmethod
def name2func(name):
return getattr(Agg, name)
@staticmethod
def sum(z):
z = jnp.where(jnp.isnan(z), 0, z)

View File

@@ -0,0 +1,69 @@
import sympy as sp
class SympySum(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z)
class SympyProduct(sp.Function):
@classmethod
def eval(cls, z):
return sp.Mul(*z)
class SympyMax(sp.Function):
@classmethod
def eval(cls, z):
return sp.Max(*z)
class SympyMin(sp.Function):
@classmethod
def eval(cls, z):
return sp.Min(*z)
class SympyMaxabs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Max(*z, key=sp.Abs)
class SympyMean(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z) / len(z)
class SympyMedian(sp.Function):
@classmethod
def eval(cls, args):
if all(arg.is_number for arg in args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
return None
@staticmethod
def numerical_eval(args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
def _sympystr(self, printer):
return f"median({', '.join(map(str, self.args))})"
def _latex(self, printer):
return (
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
)

View File

@@ -5,6 +5,7 @@ Only used in feed-forward networks.
import jax
from jax import jit, Array, numpy as jnp
from typing import Tuple, Set, List, Union
from .tools import fetch_first, I_INF
@@ -41,6 +42,60 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
return res
def topological_sort_python(
nodes: Union[Set[int], List[int]],
conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
) -> Tuple[List[int], List[List[int]]]:
# a python version of topological_sort, use python set to store nodes and conns
# returns the topological order of the nodes and the topological layers
# written by gpt4 :)
# Make a copy of the input nodes and connections
nodes = nodes.copy()
conns = conns.copy()
# Initialize the in-degree of each node to 0
in_degree = {node: 0 for node in nodes}
# Compute the in-degree for each node
for conn in conns:
in_degree[conn[1]] += 1
topo_order = []
topo_layer = []
# Find all nodes with in-degree 0
zero_in_degree_nodes = [node for node in nodes if in_degree[node] == 0]
while zero_in_degree_nodes:
for node in zero_in_degree_nodes:
nodes.remove(node)
zero_in_degree_nodes = sorted(
zero_in_degree_nodes
) # make sure the topo_order is from small to large
topo_layer.append(zero_in_degree_nodes.copy())
for node in zero_in_degree_nodes:
topo_order.append(node)
# Iterate over all connections and reduce the in-degree of connected nodes
for conn in list(conns):
if conn[0] == node:
in_degree[conn[1]] -= 1
conns.remove(conn)
zero_in_degree_nodes = [node for node in nodes if in_degree[node] == 0]
# Check if there are still connections left indicating a cycle
if conns or nodes:
raise ValueError("Graph has at least one cycle, topological sort not possible")
return topo_order, topo_layer
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
"""