add sympy support; which can transfer your network into sympy expression;
add visualize in genome; add related tests.
This commit is contained in:
211
tensorneat/examples/interpret_visualize/genome_sympy.ipynb
Normal file
211
tensorneat/examples/interpret_visualize/genome_sympy.ipynb
Normal file
@@ -0,0 +1,211 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"from algorithm.neat import *\n",
|
||||
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
||||
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
||||
"from utils.graph import topological_sort_python\n",
|
||||
"from utils import Act, Agg"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:46.886073700Z",
|
||||
"start_time": "2024-06-12T11:35:46.042288800Z"
|
||||
}
|
||||
},
|
||||
"id": "9531a569d9ecf774"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"genome = AdvanceInitialize(\n",
|
||||
" num_inputs=3,\n",
|
||||
" num_outputs=1,\n",
|
||||
" hidden_cnt=1,\n",
|
||||
" max_nodes=50,\n",
|
||||
" max_conns=500,\n",
|
||||
" node_gene=NodeGeneWithoutResponse(\n",
|
||||
" # activation_default=Act.tanh,\n",
|
||||
" aggregation_default=Agg.sum,\n",
|
||||
" # activation_options=(Act.tanh,),\n",
|
||||
" aggregation_options=(Agg.sum,),\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"state = genome.setup()\n",
|
||||
"\n",
|
||||
"randkey = jax.random.PRNGKey(42)\n",
|
||||
"nodes, conns = genome.initialize(state, randkey)\n",
|
||||
"\n",
|
||||
"network = genome.network_dict(state, nodes, conns)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.274062400Z",
|
||||
"start_time": "2024-06-12T11:35:46.892042200Z"
|
||||
}
|
||||
},
|
||||
"id": "4013c9f9d5472eb7"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sympy as sp\n",
|
||||
"\n",
|
||||
"symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n",
|
||||
"output_exprs"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.325161800Z",
|
||||
"start_time": "2024-06-12T11:35:52.282008300Z"
|
||||
}
|
||||
},
|
||||
"id": "addea793fc002900"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(sp.latex(output_exprs[0]))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.341639700Z",
|
||||
"start_time": "2024-06-12T11:35:52.323163700Z"
|
||||
}
|
||||
},
|
||||
"id": "967cb87e24373f77"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "88eee4db9eb857cd"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[-0.7940936986556304]"
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"random_inputs = np.random.randn(3)\n",
|
||||
"res = forward_func(random_inputs)\n",
|
||||
"res "
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.342638Z",
|
||||
"start_time": "2024-06-12T11:35:52.330160600Z"
|
||||
}
|
||||
},
|
||||
"id": "c5581201d990ba1c"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"transformed = genome.transform(state, nodes, conns)\n",
|
||||
"res = genome.forward(state, transformed, random_inputs)\n",
|
||||
"res"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:53.273851900Z",
|
||||
"start_time": "2024-06-12T11:35:52.384588600Z"
|
||||
}
|
||||
},
|
||||
"id": "fe3449a5bc688bc3"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:53.274854100Z",
|
||||
"start_time": "2024-06-12T11:35:53.265856700Z"
|
||||
}
|
||||
},
|
||||
"id": "174c7dc3d9499f95"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
30
tensorneat/examples/interpret_visualize/genome_sympy.py
Normal file
30
tensorneat/examples/interpret_visualize/genome_sympy.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.advance import AdvanceInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
|
||||
if __name__ == '__main__':
|
||||
genome = AdvanceInitialize(
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
hidden_cnt=8,
|
||||
max_nodes=50,
|
||||
max_conns=500,
|
||||
)
|
||||
|
||||
state = genome.setup()
|
||||
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
network = genome.network_dict(state, nodes, conns)
|
||||
print(set(network["nodes"]), set(network["conns"]))
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
print(order)
|
||||
|
||||
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
|
||||
print(input_idx, output_idx)
|
||||
|
||||
print(genome.repr(state, nodes, conns))
|
||||
print(network)
|
||||
2455
tensorneat/examples/interpret_visualize/graph.svg
Normal file
2455
tensorneat/examples/interpret_visualize/graph.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 90 KiB |
191
tensorneat/examples/interpret_visualize/network.json
Normal file
191
tensorneat/examples/interpret_visualize/network.json
Normal file
@@ -0,0 +1,191 @@
|
||||
{
|
||||
"nodes": {
|
||||
"0": {
|
||||
"bias": 0.13710324466228485,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"1": {
|
||||
"bias": -1.4202250242233276,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"2": {
|
||||
"bias": -0.4653860926628113,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"3": {
|
||||
"bias": 0.5835710167884827,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"4": {
|
||||
"bias": 2.187405824661255,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"5": {
|
||||
"bias": 0.24963024258613586,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"6": {
|
||||
"bias": -0.966821551322937,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"7": {
|
||||
"bias": 0.4452081620693207,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"8": {
|
||||
"bias": -0.07293166220188141,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"9": {
|
||||
"bias": -0.1625899225473404,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"10": {
|
||||
"bias": -0.8576332330703735,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"11": {
|
||||
"bias": -0.18487468361854553,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"12": {
|
||||
"bias": 1.4335486888885498,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"13": {
|
||||
"bias": -0.8690621256828308,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"14": {
|
||||
"bias": -0.23014676570892334,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"15": {
|
||||
"bias": 0.7880322337150574,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"16": {
|
||||
"bias": -0.22258250415325165,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"17": {
|
||||
"bias": 0.2773352861404419,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"18": {
|
||||
"bias": -0.40279051661491394,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"19": {
|
||||
"bias": 1.092000961303711,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"20": {
|
||||
"bias": -0.4063087999820709,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"21": {
|
||||
"bias": 0.3895529806613922,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"22": {
|
||||
"bias": -0.18007506430149078,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"23": {
|
||||
"bias": -0.8112533092498779,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"24": {
|
||||
"bias": 0.2946726381778717,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"25": {
|
||||
"bias": -1.118497371673584,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"26": {
|
||||
"bias": 1.3674490451812744,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"27": {
|
||||
"bias": -1.6514816284179688,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"28": {
|
||||
"bias": 0.9440701603889465,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"29": {
|
||||
"bias": 1.564852237701416,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"30": {
|
||||
"bias": -0.5568665266036987,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
}
|
||||
},
|
||||
"conns": {
|
||||
|
||||
2455
tensorneat/examples/interpret_visualize/network.svg
Normal file
2455
tensorneat/examples/interpret_visualize/network.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 89 KiB |
103
tensorneat/examples/interpret_visualize/visualize_genome.ipynb
Normal file
103
tensorneat/examples/interpret_visualize/visualize_genome.ipynb
Normal file
File diff suppressed because one or more lines are too long
13
tensorneat/examples/interpret_visualize/visualize_genome.py
Normal file
13
tensorneat/examples/interpret_visualize/visualize_genome.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 创建一个空白的有向图
|
||||
G = nx.DiGraph()
|
||||
|
||||
# 添加边
|
||||
G.add_edge('A', 'B')
|
||||
G.add_edge('A', 'C')
|
||||
G.add_edge('B', 'C')
|
||||
G.add_edge('C', 'D')
|
||||
|
||||
# 绘制有向图
|
||||
@@ -2,19 +2,19 @@ import jax, jax.numpy as jnp
|
||||
import jax.random
|
||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
||||
|
||||
|
||||
def random_policy(state, params, obs):
|
||||
# key = jax.random.key(obs.sum())
|
||||
# actions = jax.random.normal(key, (4,))
|
||||
key = jax.random.key(obs.sum())
|
||||
actions = jax.random.normal(key, (4,))
|
||||
# actions = actions.at[2:].set(-9999)
|
||||
return jnp.array([4, 4, 0, 1])
|
||||
# return jnp.array([4, 4, 0, 1])
|
||||
# return jnp.array([1, 2, 3, 4])
|
||||
# return actions
|
||||
return actions
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
problem = Jumanji_2048(
|
||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=True
|
||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=False
|
||||
)
|
||||
state = problem.setup()
|
||||
jit_evaluate = jax.jit(
|
||||
|
||||
@@ -78,36 +78,37 @@ if __name__ == "__main__":
|
||||
Act.identity,
|
||||
),
|
||||
aggregation_default=Agg.sum,
|
||||
aggregation_options=(Agg.sum,),
|
||||
aggregation_options=(Agg.sum, ),
|
||||
activation_replace_rate=0.02,
|
||||
aggregation_replace_rate=0.02,
|
||||
bias_mutate_rate=0.03,
|
||||
bias_init_std=0.5,
|
||||
bias_mutate_power=0.2,
|
||||
bias_mutate_power=0.02,
|
||||
bias_replace_rate=0.01,
|
||||
),
|
||||
conn_gene=DefaultConnGene(
|
||||
weight_mutate_rate=0.015,
|
||||
weight_replace_rate=0.003,
|
||||
weight_mutate_power=0.5,
|
||||
weight_replace_rate=0.03,
|
||||
weight_mutate_power=0.05,
|
||||
),
|
||||
mutation=DefaultMutation(node_add=0.001, conn_add=0.002),
|
||||
),
|
||||
pop_size=1000,
|
||||
species_size=5,
|
||||
survival_threshold=0.1,
|
||||
survival_threshold=0.01,
|
||||
max_stagnation=7,
|
||||
genome_elitism=3,
|
||||
compatibility_threshold=1.2,
|
||||
),
|
||||
),
|
||||
problem=Jumanji_2048(
|
||||
max_step=10000,
|
||||
repeat_times=10,
|
||||
guarantee_invalid_action=True,
|
||||
max_step=1000,
|
||||
repeat_times=50,
|
||||
# guarantee_invalid_action=True,
|
||||
guarantee_invalid_action=False,
|
||||
action_policy=action_policy,
|
||||
),
|
||||
generation_limit=1000,
|
||||
generation_limit=10000,
|
||||
fitness_target=13000,
|
||||
save_path="2048.npz",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user