add sympy support; which can transfer your network into sympy expression;

add visualize in genome;
add related tests.
This commit is contained in:
wls2002
2024-06-12 21:36:35 +08:00
parent dfc8f9198e
commit b3e442c688
29 changed files with 6196 additions and 168 deletions

4
.gitignore vendored
View File

@@ -7,3 +7,7 @@ __pycache__/
*.log *.log
.cache/ .cache/
.tmp/ .tmp/
# Ignore files named exactly 'tmp' or 'aux'
tmp
aux

View File

@@ -31,3 +31,13 @@ class BaseConnGene(BaseGene):
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format( return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
) )
def to_dict(self, state, conn):
in_idx, out_idx = conn[:2]
return {
"in": int(in_idx),
"out": int(out_idx),
}
def sympy_func(self, state, conn_dict, inputs, precision=None):
raise NotImplementedError

View File

@@ -76,3 +76,17 @@ class DefaultConnGene(BaseConnGene):
idx_width=idx_width, idx_width=idx_width,
float_width=precision + 3, float_width=precision + 3,
) )
def to_dict(self, state, conn):
return {
"in": int(conn[0]),
"out": int(conn[1]),
"weight": float(conn[2]),
}
def sympy_func(self, state, conn_dict, inputs, precision=None):
weight = conn_dict["weight"]
if precision is not None:
weight = round(weight, precision)
return inputs * weight

View File

@@ -47,3 +47,12 @@ class BaseNodeGene(BaseGene):
return "{}(idx={:<{idx_width}})".format( return "{}(idx={:<{idx_width}})".format(
self.__class__.__name__, idx, idx_width=idx_width self.__class__.__name__, idx, idx_width=idx_width
) )
def to_dict(self, state, node):
idx = node[0]
return {
"idx": int(idx),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
raise NotImplementedError

View File

@@ -1,8 +1,18 @@
from typing import Tuple from typing import Tuple
import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float from utils import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene from . import BaseNodeGene
@@ -45,12 +55,12 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_default = aggregation_options.index(aggregation_default) self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options)) self.aggregation_indices = np.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate self.aggregation_replace_rate = aggregation_replace_rate
self.activation_default = activation_options.index(activation_default) self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options)) self.activation_indices = np.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate self.activation_replace_rate = activation_replace_rate
def new_identity_attrs(self, state): def new_identity_attrs(self, state):
@@ -145,5 +155,38 @@ class DefaultNodeGene(BaseNodeGene):
act_func.__name__, act_func.__name__,
idx_width=idx_width, idx_width=idx_width,
float_width=precision + 3, float_width=precision + 3,
func_width=func_width func_width=func_width,
) )
def to_dict(self, state, node):
idx, bias, res, agg, act = node
return {
"idx": int(idx),
"bias": float(bias),
"res": float(res),
"agg": self.aggregation_options[int(agg)].__name__,
"act": self.activation_options[int(act)].__name__,
}
def sympy_func(
self, state, node_dict, inputs, is_output_node=False, precision=None
):
bias = node_dict["bias"]
res = node_dict["res"]
agg = node_dict["agg"]
act = node_dict["act"]
if precision is not None:
bias = round(bias, precision)
res = round(res, precision)
z = convert_to_sympy(agg)(inputs)
z = bias + z * res
if is_output_node:
return z
else:
z = convert_to_sympy(act)(z)
return z

View File

@@ -2,7 +2,16 @@ from typing import Tuple
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float from utils import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene from . import BaseNodeGene
@@ -121,3 +130,33 @@ class NodeGeneWithoutResponse(BaseNodeGene):
float_width=precision + 3, float_width=precision + 3,
func_width=func_width, func_width=func_width,
) )
def to_dict(self, state, node):
idx, bias, agg, act = node
return {
"idx": int(idx),
"bias": float(bias),
"agg": self.aggregation_options[int(agg)].__name__,
"act": self.activation_options[int(act)].__name__,
}
def sympy_func(
self, state, node_dict, inputs, is_output_node=False, precision=None
):
bias = node_dict["bias"]
agg = node_dict["agg"]
act = node_dict["act"]
if precision is not None:
bias = round(bias, precision)
z = convert_to_sympy(agg)(inputs)
z = bias + z
if is_output_node:
return z
else:
z = convert_to_sympy(act)(z)
return z

View File

@@ -25,8 +25,3 @@ class KANNode(BaseNodeGene):
def forward(self, state, attrs, inputs, is_output_node=False): def forward(self, state, attrs, inputs, is_output_node=False):
return Agg.sum(inputs) return Agg.sum(inputs)
def repr(self, state, node, precision=2):
idx = node[0]
idx = int(idx)
return "{}(idx: {})".format(self.__class__.__name__, idx)

View File

@@ -2,7 +2,7 @@ import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover from ..ga import BaseMutation, BaseCrossover
from utils import State, StatefulBaseClass from utils import State, StatefulBaseClass, topological_sort_python
class BaseGenome(StatefulBaseClass): class BaseGenome(StatefulBaseClass):
@@ -155,3 +155,112 @@ class BaseGenome(StatefulBaseClass):
@classmethod @classmethod
def valid_cnt(cls, arr): def valid_cnt(cls, arr):
return jnp.sum(~jnp.isnan(arr[:, 0])) return jnp.sum(~jnp.isnan(arr[:, 0]))
def get_conn_dict(self, state, conns):
conns = jax.device_get(conns)
conn_dict = {}
for conn in conns:
if np.isnan(conn[0]):
continue
cd = self.conn_gene.to_dict(state, conn)
in_idx, out_idx = cd["in"], cd["out"]
del cd["in"], cd["out"]
conn_dict[(in_idx, out_idx)] = cd
return conn_dict
def get_node_dict(self, state, nodes):
nodes = jax.device_get(nodes)
node_dict = {}
for node in nodes:
if np.isnan(node[0]):
continue
nd = self.node_gene.to_dict(state, node)
idx = nd["idx"]
del nd["idx"]
node_dict[idx] = nd
return node_dict
def network_dict(self, state, nodes, conns):
return {
"nodes": self.get_node_dict(state, nodes),
"conns": self.get_conn_dict(state, conns),
}
def get_input_idx(self):
return self.input_idx.tolist()
def get_output_idx(self):
return self.output_idx.tolist()
def sympy_func(self, state, network, precision=3):
raise NotImplementedError
def visualize(
self,
network,
rotate=0,
reverse_node_order=False,
size=(300, 300, 300),
color=("blue", "blue", "blue"),
save_path="network.svg",
save_dpi=800,
**kwargs,
):
import networkx as nx
from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"])
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
node2layer = {
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
}
if reverse_node_order:
topo_order = topo_order[::-1]
G = nx.DiGraph()
if not isinstance(size, tuple):
size = (size, size, size)
if not isinstance(color, tuple):
color = (color, color, color)
for node in topo_order:
if node in input_idx:
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
elif node in output_idx:
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
else:
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
for conn in conns_list:
G.add_edge(conn[0], conn[1])
pos = nx.multipartite_layout(G)
def rotate_layout(pos, angle):
angle_rad = np.deg2rad(angle)
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
rotated_pos = {}
for node, (x, y) in pos.items():
rotated_pos[node] = (
cos_angle * x - sin_angle * y,
sin_angle * x + cos_angle * y,
)
return rotated_pos
rotated_pos = rotate_layout(pos, rotate)
node_sizes = [n["size"] for n in G.nodes.values()]
node_colors = [n["color"] for n in G.nodes.values()]
nx.draw(
G,
with_labels=True,
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)

View File

@@ -1,17 +1,19 @@
from typing import Callable from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import sympy as sp
from utils import ( from utils import (
unflatten_conns, unflatten_conns,
topological_sort, topological_sort,
topological_sort_python,
I_INF, I_INF,
extract_node_attrs, extract_node_attrs,
extract_conn_attrs, extract_conn_attrs,
set_node_attrs, set_node_attrs,
set_conn_attrs, set_conn_attrs,
attach_with_inf, attach_with_inf,
FUNCS_MODULE,
) )
from . import BaseGenome from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
@@ -188,3 +190,56 @@ class DefaultGenome(BaseGenome):
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]), jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
new_transformed, new_transformed,
) )
def sympy_func(self, state, network, precision=3):
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
symbols = {}
for i in network["nodes"]:
if i in input_idx:
symbols[i] = sp.Symbol(f"i{i}")
elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i}")
else: # hidden
symbols[i] = sp.Symbol(f"h{i}")
nodes_exprs = {}
for i in order:
if i in input_idx:
nodes_exprs[symbols[i]] = symbols[i]
else:
in_conns = [c for c in network["conns"] if c[1] == i]
node_inputs = []
for conn in in_conns:
val_represent = symbols[conn[0]]
val = self.conn_gene.sympy_func(
state,
network["conns"][conn],
val_represent,
precision=precision,
)
node_inputs.append(val)
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
state,
network["nodes"][i],
node_inputs,
is_output_node=(i in output_idx),
precision=precision,
)
input_symbols = [v for k, v in symbols.items() if k in input_idx]
reduced_exprs = nodes_exprs.copy()
for i in order:
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
lambdify_output_funcs = [
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
for exprs in output_exprs
]
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func

View File

@@ -84,3 +84,6 @@ class RecurrentGenome(BaseGenome):
return vals[self.output_idx] return vals[self.output_idx]
else: else:
return self.output_transform(vals[self.output_idx]) return self.output_transform(vals[self.output_idx])
def sympy_func(self, state, network, precision=3):
raise ValueError("Sympy function is not supported for Recurrent Network!")

View File

@@ -0,0 +1,211 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"from algorithm.neat import *\n",
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
"from utils.graph import topological_sort_python\n",
"from utils import Act, Agg"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:46.886073700Z",
"start_time": "2024-06-12T11:35:46.042288800Z"
}
},
"id": "9531a569d9ecf774"
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"genome = AdvanceInitialize(\n",
" num_inputs=3,\n",
" num_outputs=1,\n",
" hidden_cnt=1,\n",
" max_nodes=50,\n",
" max_conns=500,\n",
" node_gene=NodeGeneWithoutResponse(\n",
" # activation_default=Act.tanh,\n",
" aggregation_default=Agg.sum,\n",
" # activation_options=(Act.tanh,),\n",
" aggregation_options=(Agg.sum,),\n",
" )\n",
")\n",
"\n",
"state = genome.setup()\n",
"\n",
"randkey = jax.random.PRNGKey(42)\n",
"nodes, conns = genome.initialize(state, randkey)\n",
"\n",
"network = genome.network_dict(state, nodes, conns)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.274062400Z",
"start_time": "2024-06-12T11:35:46.892042200Z"
}
},
"id": "4013c9f9d5472eb7"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sympy as sp\n",
"\n",
"symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n",
"output_exprs"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.325161800Z",
"start_time": "2024-06-12T11:35:52.282008300Z"
}
},
"id": "addea793fc002900"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
]
}
],
"source": [
"print(sp.latex(output_exprs[0]))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.341639700Z",
"start_time": "2024-06-12T11:35:52.323163700Z"
}
},
"id": "967cb87e24373f77"
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
},
"id": "88eee4db9eb857cd"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "[-0.7940936986556304]"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"random_inputs = np.random.randn(3)\n",
"res = forward_func(random_inputs)\n",
"res "
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:52.342638Z",
"start_time": "2024-06-12T11:35:52.330160600Z"
}
},
"id": "c5581201d990ba1c"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transformed = genome.transform(state, nodes, conns)\n",
"res = genome.forward(state, transformed, random_inputs)\n",
"res"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:53.273851900Z",
"start_time": "2024-06-12T11:35:52.384588600Z"
}
},
"id": "fe3449a5bc688bc3"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-12T11:35:53.274854100Z",
"start_time": "2024-06-12T11:35:53.265856700Z"
}
},
"id": "174c7dc3d9499f95"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,30 @@
import jax, jax.numpy as jnp
from algorithm.neat import *
from algorithm.neat.genome.advance import AdvanceInitialize
from utils.graph import topological_sort_python
if __name__ == '__main__':
genome = AdvanceInitialize(
num_inputs=17,
num_outputs=6,
hidden_cnt=8,
max_nodes=50,
max_conns=500,
)
state = genome.setup()
randkey = jax.random.PRNGKey(42)
nodes, conns = genome.initialize(state, randkey)
network = genome.network_dict(state, nodes, conns)
print(set(network["nodes"]), set(network["conns"]))
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
print(order)
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
print(input_idx, output_idx)
print(genome.repr(state, nodes, conns))
print(network)

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 90 KiB

View File

@@ -0,0 +1,191 @@
{
"nodes": {
"0": {
"bias": 0.13710324466228485,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"1": {
"bias": -1.4202250242233276,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"2": {
"bias": -0.4653860926628113,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"3": {
"bias": 0.5835710167884827,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"4": {
"bias": 2.187405824661255,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"5": {
"bias": 0.24963024258613586,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"6": {
"bias": -0.966821551322937,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"7": {
"bias": 0.4452081620693207,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"8": {
"bias": -0.07293166220188141,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"9": {
"bias": -0.1625899225473404,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"10": {
"bias": -0.8576332330703735,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"11": {
"bias": -0.18487468361854553,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"12": {
"bias": 1.4335486888885498,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"13": {
"bias": -0.8690621256828308,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"14": {
"bias": -0.23014676570892334,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"15": {
"bias": 0.7880322337150574,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"16": {
"bias": -0.22258250415325165,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"17": {
"bias": 0.2773352861404419,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"18": {
"bias": -0.40279051661491394,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"19": {
"bias": 1.092000961303711,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"20": {
"bias": -0.4063087999820709,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"21": {
"bias": 0.3895529806613922,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"22": {
"bias": -0.18007506430149078,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"23": {
"bias": -0.8112533092498779,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"24": {
"bias": 0.2946726381778717,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"25": {
"bias": -1.118497371673584,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"26": {
"bias": 1.3674490451812744,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"27": {
"bias": -1.6514816284179688,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"28": {
"bias": 0.9440701603889465,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"29": {
"bias": 1.564852237701416,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"30": {
"bias": -0.5568665266036987,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
}
},
"conns": {

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 89 KiB

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
import networkx as nx
import matplotlib.pyplot as plt
# 创建一个空白的有向图
G = nx.DiGraph()
# 添加边
G.add_edge('A', 'B')
G.add_edge('A', 'C')
G.add_edge('B', 'C')
G.add_edge('C', 'D')
# 绘制有向图

View File

@@ -2,19 +2,19 @@ import jax, jax.numpy as jnp
import jax.random import jax.random
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048 from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
def random_policy(state, params, obs): def random_policy(state, params, obs):
# key = jax.random.key(obs.sum()) key = jax.random.key(obs.sum())
# actions = jax.random.normal(key, (4,)) actions = jax.random.normal(key, (4,))
# actions = actions.at[2:].set(-9999) # actions = actions.at[2:].set(-9999)
return jnp.array([4, 4, 0, 1]) # return jnp.array([4, 4, 0, 1])
# return jnp.array([1, 2, 3, 4]) # return jnp.array([1, 2, 3, 4])
# return actions
return actions return actions
if __name__ == "__main__": if __name__ == "__main__":
problem = Jumanji_2048( problem = Jumanji_2048(
max_step=10000, repeat_times=1000, guarantee_invalid_action=True max_step=10000, repeat_times=1000, guarantee_invalid_action=False
) )
state = problem.setup() state = problem.setup()
jit_evaluate = jax.jit( jit_evaluate = jax.jit(

View File

@@ -78,36 +78,37 @@ if __name__ == "__main__":
Act.identity, Act.identity,
), ),
aggregation_default=Agg.sum, aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,), aggregation_options=(Agg.sum, ),
activation_replace_rate=0.02, activation_replace_rate=0.02,
aggregation_replace_rate=0.02, aggregation_replace_rate=0.02,
bias_mutate_rate=0.03, bias_mutate_rate=0.03,
bias_init_std=0.5, bias_init_std=0.5,
bias_mutate_power=0.2, bias_mutate_power=0.02,
bias_replace_rate=0.01, bias_replace_rate=0.01,
), ),
conn_gene=DefaultConnGene( conn_gene=DefaultConnGene(
weight_mutate_rate=0.015, weight_mutate_rate=0.015,
weight_replace_rate=0.003, weight_replace_rate=0.03,
weight_mutate_power=0.5, weight_mutate_power=0.05,
), ),
mutation=DefaultMutation(node_add=0.001, conn_add=0.002), mutation=DefaultMutation(node_add=0.001, conn_add=0.002),
), ),
pop_size=1000, pop_size=1000,
species_size=5, species_size=5,
survival_threshold=0.1, survival_threshold=0.01,
max_stagnation=7, max_stagnation=7,
genome_elitism=3, genome_elitism=3,
compatibility_threshold=1.2, compatibility_threshold=1.2,
), ),
), ),
problem=Jumanji_2048( problem=Jumanji_2048(
max_step=10000, max_step=1000,
repeat_times=10, repeat_times=50,
guarantee_invalid_action=True, # guarantee_invalid_action=True,
guarantee_invalid_action=False,
action_policy=action_policy, action_policy=action_policy,
), ),
generation_limit=1000, generation_limit=10000,
fitness_target=13000, fitness_target=13000,
save_path="2048.npz", save_path="2048.npz",
) )

View File

@@ -18,8 +18,7 @@ class Jumanji_2048(RLEnv):
################################################################### ###################################################################
action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)]) # action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
action = (action - 1) / 15
################################################################### ###################################################################
if self.guarantee_invalid_action: if self.guarantee_invalid_action:

View File

@@ -2,189 +2,110 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-06-06T11:55:39.434327400Z",
"start_time": "2024-06-06T11:55:39.361327400Z"
}
},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)" "text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
}, },
"execution_count": 22, "execution_count": 1,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"import jax, jax.numpy as jnp\n", "import jax, jax.numpy as jnp\n",
"a = jnp.array([\n", "from algorithm.neat import *\n",
" [1, 2],\n", "from utils import Act, Agg\n",
" [3, 4]\n", "genome = DefaultGenome(\n",
"])\n", " num_inputs=27,\n",
"def rot_boards(board):\n", " num_outputs=8,\n",
" def rot(a, _):\n", " max_nodes=100,\n",
" a = jnp.rot90(a)\n", " max_conns=200,\n",
" return a, a # carry, y\n", " node_gene=DefaultNodeGene(\n",
" \n", " activation_options=(Act.tanh,),\n",
" _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))\n", " activation_default=Act.tanh,\n",
" return boards\n", " ),\n",
"a1 = rot_boards(a)\n", " output_transform=Act.tanh,\n",
"a2 = rot_boards(a)\n", ")\n",
"\n", "state = genome.setup()\n",
"a = jnp.concatenate([a1, a2], axis=0)\n", "genome"
"a"
]
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)"
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = a.reshape(8, -1)\n",
"a"
], ],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"ExecuteTime": { "ExecuteTime": {
"end_time": "2024-06-06T11:55:31.121054800Z", "end_time": "2024-06-09T12:08:22.569123400Z",
"start_time": "2024-06-06T11:55:31.075517200Z" "start_time": "2024-06-09T12:08:19.331863800Z"
} }
}, },
"id": "639cdecea840351d" "id": "b2b214a5454c4814"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"state = state.register(data=jnp.zeros((1, 27)))\n",
"# try to save the genome object\n",
"import pickle\n",
"\n",
"with open('genome.pkl', 'wb') as f:\n",
" genome.__dict__[\"state\"] = state\n",
" pickle.dump(genome, f)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:09:01.943445900Z",
"start_time": "2024-06-09T12:09:01.919416Z"
}
},
"id": "28348dfc458e8473"
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 13,
"outputs": [], "outputs": [],
"source": [ "source": [
"action = [\"up\", \"right\", \"down\", \"left\"]\n", "# try to load the genome object\n",
"lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n", "with open('genome.pkl', 'rb') as f:\n",
"def action_rot90(li):\n", " genome = pickle.load(f)\n",
" first = li[0]\n", " state = genome.state\n",
" return li[1:] + [first]\n", " del genome.__dict__[\"state\"]"
"\n",
"a = a\n",
"rl_flip_a = jnp.fliplr(a)"
], ],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"ExecuteTime": { "ExecuteTime": {
"end_time": "2024-06-06T11:22:36.417287600Z", "end_time": "2024-06-09T12:10:28.621539400Z",
"start_time": "2024-06-06T11:22:36.414285500Z" "start_time": "2024-06-09T12:10:28.612540100Z"
} }
}, },
"id": "92b75cd0e870a28c" "id": "c91be9fe3d2b5d5d"
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 15,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1 2]\n",
" [3 4]] ['up', 'right', 'down', 'left']\n",
"[[2 1]\n",
" [4 3]] ['up', 'left', 'down', 'right']\n",
"[[2 4]\n",
" [1 3]] ['right', 'down', 'left', 'up']\n",
"[[1 3]\n",
" [2 4]] ['left', 'down', 'right', 'up']\n",
"[[4 3]\n",
" [2 1]] ['down', 'left', 'up', 'right']\n",
"[[3 4]\n",
" [1 2]] ['down', 'right', 'up', 'left']\n",
"[[3 1]\n",
" [4 2]] ['left', 'up', 'right', 'down']\n",
"[[4 2]\n",
" [3 1]] ['right', 'up', 'left', 'down']\n"
]
}
],
"source": [
"for i in range(4):\n",
" print(a, action)\n",
" print(rl_flip_a, lr_flip_action)\n",
" a = jnp.rot90(a)\n",
" rl_flip_a = jnp.rot90(rl_flip_a)\n",
" action = action_rot90(action)\n",
" lr_flip_action = action_rot90(lr_flip_action)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.919614600Z",
"start_time": "2024-06-06T11:22:36.860704600Z"
}
},
"id": "55e802e0dbcc9c7f"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)" "text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
}, },
"execution_count": 6, "execution_count": 15,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"jnp.rot90(a, k=2)" "state"
], ],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"ExecuteTime": { "ExecuteTime": {
"end_time": "2024-06-06T11:12:48.186719Z", "end_time": "2024-06-09T12:10:34.103124Z",
"start_time": "2024-06-06T11:12:48.151161900Z" "start_time": "2024-06-09T12:10:34.096124300Z"
} }
}, },
"id": "16f8de3cadaa257a" "id": "6852e4e58b81dd9"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# flip left-right\n",
"jnp.fliplr(a)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:14:28.668195300Z",
"start_time": "2024-06-06T11:14:28.631570500Z"
}
},
"id": "1fffa4e597ab5732"
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -194,7 +115,7 @@
"metadata": { "metadata": {
"collapsed": false "collapsed": false
}, },
"id": "ca53c916dcff12ae" "id": "97a50322218a0427"
} }
], ],
"metadata": { "metadata": {

View File

@@ -1,6 +1,51 @@
from .activation import Act, act_func, ACT_ALL from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
from .aggregation import Agg, agg_func, AGG_ALL
from .tools import * from .tools import *
from .graph import * from .graph import *
from .state import State from .state import State
from .stateful_class import StatefulBaseClass from .stateful_class import StatefulBaseClass
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
from .activation.act_jnp import Act, ACT_ALL, act_func
from .aggregation.agg_sympy import *
from .activation.act_sympy import *
from typing import Union
name2sympy = {
"sigmoid": SympySigmoid,
"tanh": SympyTanh,
"sin": SympySin,
"relu": SympyRelu,
"lelu": SympyLelu,
"identity": SympyIdentity,
"clamped": SympyClamped,
"inv": SympyInv,
"log": SympyLog,
"exp": SympyExp,
"abs": SympyAbs,
"sum": SympySum,
"product": SympyProduct,
"max": SympyMax,
"min": SympyMin,
"maxabs": SympyMaxabs,
"mean": SympyMean,
}
def convert_to_sympy(func: Union[str, callable]):
if isinstance(func, str):
name = func
else:
name = func.__name__
if name in name2sympy:
return name2sympy[name]
else:
raise ValueError(
f"Can not convert to sympy! Function {name} not found in name2sympy"
)
FUNCS_MODULE = {}
for cls in name2sympy.values():
if hasattr(cls, "numerical_eval"):
FUNCS_MODULE[cls.__name__] = cls.numerical_eval

View File

View File

@@ -3,6 +3,10 @@ import jax.numpy as jnp
class Act: class Act:
@staticmethod
def name2func(name):
return getattr(Act, name)
@staticmethod @staticmethod
def sigmoid(z): def sigmoid(z):
z = jnp.clip(5 * z, -10, 10) z = jnp.clip(5 * z, -10, 10)

View File

@@ -0,0 +1,191 @@
from typing import Union
import sympy as sp
import numpy as np
class SympyClip(sp.Function):
@classmethod
def eval(cls, val, min_val, max_val):
if val.is_Number and min_val.is_Number and max_val.is_Number:
return sp.Piecewise(
(min_val, val < min_val), (max_val, val > max_val), (val, True)
)
return None
@staticmethod
def numerical_eval(val, min_val, max_val):
return np.clip(val, min_val, max_val)
def _sympystr(self, printer):
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
def _latex(self, printer):
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
class SympySigmoid(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(5 * z, -10, 10)
return 1 / (1 + sp.exp(-z))
return None
@staticmethod
def numerical_eval(z):
z = np.clip(5 * z, -10, 10)
return 1 / (1 + np.exp(-z))
def _sympystr(self, printer):
return f"sigmoid({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
class SympyTanh(sp.Function):
@classmethod
def eval(cls, z):
return sp.tanh(0.6 * z)
@staticmethod
def numerical_eval(z):
return np.tanh(0.6 * z)
class SympySin(sp.Function):
@classmethod
def eval(cls, z):
return sp.sin(z)
@staticmethod
def numerical_eval(z):
return np.sin(z)
class SympyRelu(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
return sp.Piecewise((z, z > 0), (0, True))
return None
@staticmethod
def numerical_eval(z):
return np.maximum(z, 0)
def _sympystr(self, printer):
return f"relu({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
class SympyLelu(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
leaky = 0.005
return sp.Piecewise((z, z > 0), (leaky * z, True))
return None
@staticmethod
def numerical_eval(z):
leaky = 0.005
return np.maximum(z, leaky * z)
def _sympystr(self, printer):
return f"lelu({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
class SympyIdentity(sp.Function):
@classmethod
def eval(cls, z):
return z
@staticmethod
def numerical_eval(z):
return z
class SympyClamped(sp.Function):
@classmethod
def eval(cls, z):
return SympyClip(z, -1, 1)
@staticmethod
def numerical_eval(z):
return np.clip(z, -1, 1)
class SympyInv(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
return 1 / z
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return 1 / z
def _sympystr(self, printer):
return f"1 / {self.args[0]}"
def _latex(self, printer):
return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
class SympyLog(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = sp.Max(z, 1e-7)
return sp.log(z)
return None
@staticmethod
def numerical_eval(z):
z = np.maximum(z, 1e-7)
return np.log(z)
def _sympystr(self, printer):
return f"log({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
class SympyExp(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(z, -10, 10)
return sp.exp(z)
return None
@staticmethod
def numerical_eval(z):
z = np.clip(z, -10, 10)
return np.exp(z)
def _sympystr(self, printer):
return f"exp({self.args[0]})"
def _latex(self, printer):
return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
class SympyAbs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Abs(z)
@staticmethod
def numerical_eval(z):
return np.abs(z)

View File

View File

@@ -3,6 +3,10 @@ import jax.numpy as jnp
class Agg: class Agg:
@staticmethod
def name2func(name):
return getattr(Agg, name)
@staticmethod @staticmethod
def sum(z): def sum(z):
z = jnp.where(jnp.isnan(z), 0, z) z = jnp.where(jnp.isnan(z), 0, z)

View File

@@ -0,0 +1,69 @@
import sympy as sp
class SympySum(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z)
class SympyProduct(sp.Function):
@classmethod
def eval(cls, z):
return sp.Mul(*z)
class SympyMax(sp.Function):
@classmethod
def eval(cls, z):
return sp.Max(*z)
class SympyMin(sp.Function):
@classmethod
def eval(cls, z):
return sp.Min(*z)
class SympyMaxabs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Max(*z, key=sp.Abs)
class SympyMean(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z) / len(z)
class SympyMedian(sp.Function):
@classmethod
def eval(cls, args):
if all(arg.is_number for arg in args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
return None
@staticmethod
def numerical_eval(args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
def _sympystr(self, printer):
return f"median({', '.join(map(str, self.args))})"
def _latex(self, printer):
return (
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
)

View File

@@ -5,6 +5,7 @@ Only used in feed-forward networks.
import jax import jax
from jax import jit, Array, numpy as jnp from jax import jit, Array, numpy as jnp
from typing import Tuple, Set, List, Union
from .tools import fetch_first, I_INF from .tools import fetch_first, I_INF
@@ -41,6 +42,60 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
return res return res
def topological_sort_python(
nodes: Union[Set[int], List[int]],
conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
) -> Tuple[List[int], List[List[int]]]:
# a python version of topological_sort, use python set to store nodes and conns
# returns the topological order of the nodes and the topological layers
# written by gpt4 :)
# Make a copy of the input nodes and connections
nodes = nodes.copy()
conns = conns.copy()
# Initialize the in-degree of each node to 0
in_degree = {node: 0 for node in nodes}
# Compute the in-degree for each node
for conn in conns:
in_degree[conn[1]] += 1
topo_order = []
topo_layer = []
# Find all nodes with in-degree 0
zero_in_degree_nodes = [node for node in nodes if in_degree[node] == 0]
while zero_in_degree_nodes:
for node in zero_in_degree_nodes:
nodes.remove(node)
zero_in_degree_nodes = sorted(
zero_in_degree_nodes
) # make sure the topo_order is from small to large
topo_layer.append(zero_in_degree_nodes.copy())
for node in zero_in_degree_nodes:
topo_order.append(node)
# Iterate over all connections and reduce the in-degree of connected nodes
for conn in list(conns):
if conn[0] == node:
in_degree[conn[1]] -= 1
conns.remove(conn)
zero_in_degree_nodes = [node for node in nodes if in_degree[node] == 0]
# Check if there are still connections left indicating a cycle
if conns or nodes:
raise ValueError("Graph has at least one cycle, topological sort not possible")
return topo_order, topo_layer
@jit @jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
""" """