update a lot, take a break

This commit is contained in:
root
2024-07-12 07:47:33 +08:00
parent 58c56ab2ab
commit 99b8f7fd90
11 changed files with 2161 additions and 2418 deletions

View File

@@ -26,7 +26,7 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (N
- JAX-based network for neuroevolution:
- **Batch inference** across networks with different architectures, GPU-accelerated.
- Evolve networks with **irregular structures** and **fully customize** their behavior.
- Visualize the network and represent it in **mathematical formulas**.
- Visualize the network and represent it in **mathematical formulas** or codes.
- GPU-accelerated NEAT implementation:
- Run NEAT and HyperNEAT on GPUs.
@@ -75,6 +75,60 @@ state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)
```
Obtain result in a few generations:
```
Fitness limit reached!
input: [0. 0. 0.], target: [0.], predict: [0.00066471]
input: [0. 0. 1.], target: [1.], predict: [0.9992988]
input: [0. 1. 0.], target: [1.], predict: [0.9988666]
input: [0. 1. 1.], target: [0.], predict: [0.00107922]
input: [1. 0. 0.], target: [1.], predict: [0.9987184]
input: [1. 0. 1.], target: [0.], predict: [0.00093677]
input: [1. 1. 0.], target: [0.], predict: [0.00060118]
input: [1. 1. 1.], target: [1.], predict: [0.99927646]
loss: 8.484730074087565e-07
```
4. **Visualize the best network**:
```python
network = algorithm.genome.network_dict(state, *best)
algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
```
<div style="text-align: center;">
<img src="./imgs/xor_network.svg" alt="Visualization of the policy" width="200" height="200">
</div>
5. **Transform the network to latex formulas or python codes**:
```python
from tensorneat.common.sympy_tools import to_latex_code, to_python_code
sympy_res = algorithm.genome.sympy_func(
state, network, sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid)
)
latex_code = to_latex_code(*sympy_res)
print(latex_code)
python_code = to_python_code(*sympy_res)
print(python_code)
```
Obtain latex formulas:
```latex
\begin{align}
h_{0} &= \frac{1}{2.83 e^{5.66 h_{1} - 6.08 h_{2} - 3.03 i_{2}} + 1}\newline
h_{1} &= \frac{1}{0.3 e^{- 4.8 h_{2} + 9.22 i_{0} + 8.09 i_{1} - 10.24 i_{2}} + 1}\newline
h_{2} &= \frac{1}{0.27 e^{4.28 i_{1}} + 1}\newline
o_{0} &= \frac{1}{0.68 e^{- 20.86 h_{0} + 11.12 h_{1} + 14.22 i_{0} - 1.96 i_{2}} + 1}\newline
\end{align}
```
Obtain python codes:
```python
h = np.zeros(3)
o = np.zeros(1)
h[0] = 1/(2.825013*exp(5.660946*h[1] - 6.083459*h[2] - 3.033361*i[2]) + 1)
h[1] = 1/(0.300038*exp(-4.802896*h[2] + 9.215506*i[0] + 8.091845*i[1] - 10.241107*i[2]) + 1)
h[2] = 1/(0.269965*exp(4.279962*i[1]) + 1)
o[0] = 1/(0.679321*exp(-20.860441*h[0] + 11.122242*h[1] + 14.216276*i[0] - 1.961642*i[2]) + 1)
```
## Installation
Install `tensorneat` from the GitHub source code:

View File

@@ -1,5 +1,6 @@
from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem, common
from tensorneat import algorithm, genome, problem
from tensorneat.common import ACT
algorithm = algorithm.NEAT(
pop_size=10000,
@@ -8,7 +9,8 @@ algorithm = algorithm.NEAT(
genome=genome.DefaultGenome(
num_inputs=3,
num_outputs=1,
output_transform=common.ACT.sigmoid,
max_nodes=7,
output_transform=ACT.sigmoid,
),
)
problem = problem.XOR3d()
@@ -25,3 +27,20 @@ state = pipeline.setup()
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
# visualize the best individual
network = algorithm.genome.network_dict(state, *best)
algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
# transform the best individual to latex formula
from tensorneat.common.sympy_tools import to_latex_code, to_python_code
sympy_res = algorithm.genome.sympy_func(
state, network, sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid)
)
latex_code = to_latex_code(*sympy_res)
print(latex_code)
# transform the best individual to python code
python_code = to_python_code(*sympy_res)
print(python_code)

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 89 KiB

After

Width:  |  Height:  |  Size: 18 KiB

File diff suppressed because one or more lines are too long

View File

@@ -1,13 +1,9 @@
import networkx as nx
import matplotlib.pyplot as plt
# 创建一个空白的有向图
G = nx.DiGraph()
# 添加边
G.add_edge('A', 'B')
G.add_edge('A', 'C')
G.add_edge('B', 'C')
G.add_edge('C', 'D')
# 绘制有向图
h = np.zeros(3)
o = np.zeros(4)
h[0] = -tanh(0.540486*i[0] + 1.04397*i[1] + 0.58006*i[10] + 0.658223*i[11] - 0.9918*i[12] - 0.01919*i[13] + 0.194062*i[14] + 0.903314*i[15] - 1.906567*i[2] - 1.666336*i[3] + 0.653257*i[4] + 0.580191*i[5] + 0.177264*i[6] + 0.830688*i[7] - 0.855676*i[8] + 0.326538*i[9] + 2.465507)
h[1] = -tanh(1.441044*i[0] - 0.606109*i[1] - 0.736058*i[10] + 0.60264*i[11] - 0.837565*i[12] + 2.018719*i[13] + 0.327097*i[14] + 0.098963*i[15] + 0.403485*i[2] - 0.680547*i[3] + 0.349021*i[4] - 1.359364*i[5] + 0.351466*i[6] + 0.450447*i[7] + 2.102749*i[8] + 0.680605*i[9] + 0.593945)
h[2] = -tanh(1.350645*i[0] - 0.281682*i[1] + 0.332992*i[10] + 0.703457*i[11] + 1.290286*i[12] - 1.059887*i[13] - 1.114513*i[14] + 0.446127*i[15] + 1.103008*i[2] + 1.080698*i[3] - 0.89471*i[4] + 0.103146*i[5] - 0.828767*i[6] + 0.609362*i[7] - 0.765917*i[8] + 0.047898*i[9] + 0.649254)
o[0] = -1.307307*h[0] - 0.985838*h[1] - 0.746408*h[2] + 0.245725
o[1] = 0.64947*h[0] + 2.865669*h[1] + 1.185759*h[2] - 1.347174
o[2] = 2.030407*h[0] + 1.001914*h[1] - 1.041287*h[2] + 0.301639
o[3] = 0.717661*h[0] + 0.653905*h[1] - 1.387949*h[2] - 1.200779

226
imgs/xor_network.svg Normal file
View File

@@ -0,0 +1,226 @@
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
<metadata>
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:date>2024-07-12T07:46:33.195425</dc:date>
<dc:format>image/svg+xml</dc:format>
<dc:creator>
<cc:Agent>
<dc:title>Matplotlib v3.9.0, https://matplotlib.org/</dc:title>
</cc:Agent>
</dc:creator>
</cc:Work>
</rdf:RDF>
</metadata>
<defs>
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
</defs>
<g id="figure_1">
<g id="patch_1">
<path d="M 0 345.6
L 460.8 345.6
L 460.8 0
L 0 0
z
" style="fill: #ffffff"/>
</g>
<g id="axes_1">
<g id="patch_2">
<path d="M 48.095958 312.568974
Q 230.401029 244.204573 411.659252 176.232739
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 410.324982 176.09229
L 411.659252 176.232739
L 410.746331 177.215885
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_3">
<path d="M 46.916335 310.412749
Q 135.193491 244.204882 222.57622 178.667835
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 221.25622 178.907835
L 222.57622 178.667835
L 221.97622 179.867835
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_4">
<path d="M 48.647998 172.8
Q 135.193467 172.8 220.620902 172.8
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 219.420902 172.2
L 220.620902 172.8
L 219.420902 173.4
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_5">
<path d="M 48.647998 172.8
Q 87.590519 172.8 125.415005 172.8
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 124.215005 172.2
L 125.415005 172.8
L 124.215005 173.4
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_6">
<path d="M 46.916335 35.187251
Q 135.193491 101.395118 222.57622 166.932165
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 221.97622 165.732165
L 222.57622 166.932165
L 221.25622 166.692165
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_7">
<path d="M 48.095958 33.031026
Q 230.401029 101.395427 411.659252 169.367261
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 410.746331 168.384115
L 411.659252 169.367261
L 410.324982 169.50771
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_8">
<path d="M 47.731321 33.862355
Q 182.795689 101.394539 316.860058 168.426723
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 316.055073 167.35341
L 316.860058 168.426723
L 315.518417 168.426723
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_9">
<path d="M 143.85461 172.8
Q 182.79713 172.8 220.621616 172.8
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 219.421616 172.2
L 220.621616 172.8
L 219.421616 173.4
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_10">
<path d="M 143.85461 172.8
Q 230.400079 172.8 315.827513 172.8
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 314.627513 172.2
L 315.827513 172.8
L 314.627513 173.4
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_11">
<path d="M 239.061222 172.8
Q 278.003742 172.8 315.828228 172.8
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 314.628228 172.2
L 315.828228 172.8
L 314.628228 173.4
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_12">
<path d="M 239.061222 172.8
Q 325.60669 172.8 411.034125 172.8
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 409.834125 172.2
L 411.034125 172.8
L 409.834125 173.4
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_13">
<path d="M 334.267833 172.8
Q 373.210353 172.8 411.03484 172.8
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 409.83484 172.2
L 411.03484 172.8
L 409.83484 173.4
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="PathCollection_1">
<path d="M 39.986777 324.270171
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
z
" clip-path="url(#pbc29be279d)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 39.986777 181.460254
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
z
" clip-path="url(#pbc29be279d)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 39.986777 38.650337
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
z
" clip-path="url(#pbc29be279d)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 135.193388 181.460254
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
C 143.853642 170.503274 142.941143 168.300306 141.317113 166.676276
C 139.693082 165.052245 137.490115 164.139746 135.193388 164.139746
C 132.896662 164.139746 130.693695 165.052245 129.069664 166.676276
C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
z
" clip-path="url(#pbc29be279d)" style="fill: #ffffff; stroke: #000000"/>
<path d="M 230.4 181.460254
C 232.696726 181.460254 234.899694 180.547755 236.523724 178.923724
C 238.147755 177.299694 239.060254 175.096726 239.060254 172.8
C 239.060254 170.503274 238.147755 168.300306 236.523724 166.676276
C 234.899694 165.052245 232.696726 164.139746 230.4 164.139746
C 228.103274 164.139746 225.900306 165.052245 224.276276 166.676276
C 222.652245 168.300306 221.739746 170.503274 221.739746 172.8
C 221.739746 175.096726 222.652245 177.299694 224.276276 178.923724
C 225.900306 180.547755 228.103274 181.460254 230.4 181.460254
z
" clip-path="url(#pbc29be279d)" style="fill: #ffffff; stroke: #000000"/>
<path d="M 325.606612 181.460254
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
C 334.266866 170.503274 333.354367 168.300306 331.730336 166.676276
C 330.106305 165.052245 327.903338 164.139746 325.606612 164.139746
C 323.309885 164.139746 321.106918 165.052245 319.482887 166.676276
C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
z
" clip-path="url(#pbc29be279d)" style="fill: #ffffff; stroke: #000000"/>
<path d="M 420.813223 181.460254
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
z
" clip-path="url(#pbc29be279d)" style="fill: #0000ff; stroke: #000000"/>
</g>
</g>
</g>
<defs>
<clipPath id="pbc29be279d">
<rect x="0" y="0" width="460.8" height="345.6"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 9.9 KiB

View File

@@ -0,0 +1,91 @@
import re
import sympy as sp
def analysis_nodes_exprs(nodes_exprs):
input_cnt, hidden_cnt, output_cnt = 0, 0, 0
norm_symbols = {}
for key in nodes_exprs.keys():
if str(key).startswith('i'):
input_cnt += 1
elif str(key).startswith('h'):
hidden_cnt += 1
elif str(key).startswith('o'):
output_cnt += 1
elif str(key).startswith('norm'):
norm_symbols[key] = nodes_exprs[key]
return input_cnt, hidden_cnt, output_cnt, norm_symbols
def round_expr(expr, precision=2):
"""
Round numerical values in a sympy expression to a given precision.
"""
return expr.xreplace({n: round(n, precision) for n in expr.atoms(sp.Number)})
def replace_variable_names(expression):
"""
Transform sympy expression to a string with array index that can be used in python code.
For example, `o0` will be transformed to `o[0]`.
"""
expression_str = str(expression)
expression_str = re.sub(r"\bo(\d+)\b", r"o[\1]", expression_str)
expression_str = re.sub(r"\bh(\d+)\b", r"h[\1]", expression_str)
expression_str = re.sub(r"\bi(\d+)\b", r"i[\1]", expression_str)
return expression_str
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True):
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
res = "\\begin{align}\n"
if not use_hidden_nodes:
for i in range(output_cnt):
expr = output_exprs[i].subs(args_symbols)
rounded_expr = round_expr(expr, 2)
latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
res += latex_expr
else:
for i in range(hidden_cnt):
symbol = sp.symbols(f"h{i}")
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
rounded_expr = round_expr(expr, 2)
latex_expr = f"h_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
res += latex_expr
for i in range(output_cnt):
symbol = sp.symbols(f"o{i}")
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
rounded_expr = round_expr(expr, 2)
latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
res += latex_expr
res += "\\end{align}\n"
return res
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True):
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
res = ""
if not use_hidden_nodes:
# pre-allocate space
res += f"o = np.zeros({output_cnt})\n"
for i in range(output_cnt):
expr = output_exprs[i].subs(args_symbols)
rounded_expr = round_expr(expr, 6)
str_expr = f"o{i} = {rounded_expr}"
res += replace_variable_names(str_expr) + "\n"
else:
# pre-allocate space
res += f"h = np.zeros({hidden_cnt})\n"
res += f"o = np.zeros({output_cnt})\n"
for i in range(hidden_cnt):
symbol = sp.symbols(f"h{i}")
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
rounded_expr = round_expr(expr, 6)
str_expr = f"h{i} = {rounded_expr}"
res += replace_variable_names(str_expr) + "\n"
for i in range(output_cnt):
symbol = sp.symbols(f"o{i}")
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
rounded_expr = round_expr(expr, 6)
str_expr = f"o{i} = {rounded_expr}"
res += replace_variable_names(str_expr) + "\n"
return res

View File

@@ -10,7 +10,7 @@ from tensorneat.common import (
StatefulBaseClass,
hash_array,
)
from .utils import valid_cnt
from .utils import valid_cnt, re_cound_idx
class BaseGenome(StatefulBaseClass):
@@ -160,7 +160,11 @@ class BaseGenome(StatefulBaseClass):
return nodes, conns
def network_dict(self, state, nodes, conns):
def network_dict(self, state, nodes, conns, whether_re_cound_idx=True):
if whether_re_cound_idx:
nodes, conns = re_cound_idx(
nodes, conns, self.get_input_idx(), self.get_output_idx()
)
return {
"nodes": self._get_node_dict(state, nodes),
"conns": self._get_conn_dict(state, conns),

View File

@@ -209,7 +209,6 @@ class DefaultNode(BaseNode):
bias = sp.symbols(f"n_{nd['idx']}_b")
res = sp.symbols(f"n_{nd['idx']}_r")
print(nd["agg"])
z = AGG.obtain_sympy(nd["agg"])(inputs)
z = bias + res * z

View File

@@ -1,5 +1,6 @@
import jax
from jax import vmap, numpy as jnp
import numpy as np
from tensorneat.common import fetch_first, I_INF
@@ -107,3 +108,33 @@ def delete_conn_by_pos(conns, pos):
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)
def re_cound_idx(nodes, conns, input_idx, output_idx):
"""
Make the key of hidden nodes continuous.
Also update the index of connections.
"""
nodes, conns = jax.device_get((nodes, conns))
next_key = max(*input_idx, *output_idx) + 1
old2new = {}
for i, key in enumerate(nodes[:, 0]):
if np.isnan(key):
continue
if np.in1d(key, input_idx + output_idx):
continue
old2new[int(key)] = next_key
next_key += 1
new_nodes = nodes.copy()
for i, key in enumerate(nodes[:, 0]):
if (not np.isnan(key)) and int(key) in old2new:
new_nodes[i, 0] = old2new[int(key)]
new_conns = conns.copy()
for i, (i_key, o_key) in enumerate(conns[:, :2]):
if (not np.isnan(i_key)) and int(i_key) in old2new:
new_conns[i, 0] = old2new[int(i_key)]
if (not np.isnan(o_key)) and int(o_key) in old2new:
new_conns[i, 1] = old2new[int(o_key)]
return new_nodes, new_conns