update a lot, take a break
This commit is contained in:
56
README.md
56
README.md
@@ -26,7 +26,7 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (N
|
||||
- JAX-based network for neuroevolution:
|
||||
- **Batch inference** across networks with different architectures, GPU-accelerated.
|
||||
- Evolve networks with **irregular structures** and **fully customize** their behavior.
|
||||
- Visualize the network and represent it in **mathematical formulas**.
|
||||
- Visualize the network and represent it in **mathematical formulas** or codes.
|
||||
|
||||
- GPU-accelerated NEAT implementation:
|
||||
- Run NEAT and HyperNEAT on GPUs.
|
||||
@@ -75,6 +75,60 @@ state, best = pipeline.auto_run(state)
|
||||
# show results
|
||||
pipeline.show(state, best)
|
||||
```
|
||||
Obtain result in a few generations:
|
||||
```
|
||||
Fitness limit reached!
|
||||
input: [0. 0. 0.], target: [0.], predict: [0.00066471]
|
||||
input: [0. 0. 1.], target: [1.], predict: [0.9992988]
|
||||
input: [0. 1. 0.], target: [1.], predict: [0.9988666]
|
||||
input: [0. 1. 1.], target: [0.], predict: [0.00107922]
|
||||
input: [1. 0. 0.], target: [1.], predict: [0.9987184]
|
||||
input: [1. 0. 1.], target: [0.], predict: [0.00093677]
|
||||
input: [1. 1. 0.], target: [0.], predict: [0.00060118]
|
||||
input: [1. 1. 1.], target: [1.], predict: [0.99927646]
|
||||
loss: 8.484730074087565e-07
|
||||
```
|
||||
4. **Visualize the best network**:
|
||||
```python
|
||||
network = algorithm.genome.network_dict(state, *best)
|
||||
algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
|
||||
```
|
||||
<div style="text-align: center;">
|
||||
<img src="./imgs/xor_network.svg" alt="Visualization of the policy" width="200" height="200">
|
||||
</div>
|
||||
|
||||
5. **Transform the network to latex formulas or python codes**:
|
||||
```python
|
||||
from tensorneat.common.sympy_tools import to_latex_code, to_python_code
|
||||
|
||||
sympy_res = algorithm.genome.sympy_func(
|
||||
state, network, sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid)
|
||||
)
|
||||
latex_code = to_latex_code(*sympy_res)
|
||||
print(latex_code)
|
||||
|
||||
python_code = to_python_code(*sympy_res)
|
||||
print(python_code)
|
||||
```
|
||||
Obtain latex formulas:
|
||||
```latex
|
||||
\begin{align}
|
||||
h_{0} &= \frac{1}{2.83 e^{5.66 h_{1} - 6.08 h_{2} - 3.03 i_{2}} + 1}\newline
|
||||
h_{1} &= \frac{1}{0.3 e^{- 4.8 h_{2} + 9.22 i_{0} + 8.09 i_{1} - 10.24 i_{2}} + 1}\newline
|
||||
h_{2} &= \frac{1}{0.27 e^{4.28 i_{1}} + 1}\newline
|
||||
o_{0} &= \frac{1}{0.68 e^{- 20.86 h_{0} + 11.12 h_{1} + 14.22 i_{0} - 1.96 i_{2}} + 1}\newline
|
||||
\end{align}
|
||||
```
|
||||
Obtain python codes:
|
||||
```python
|
||||
h = np.zeros(3)
|
||||
o = np.zeros(1)
|
||||
h[0] = 1/(2.825013*exp(5.660946*h[1] - 6.083459*h[2] - 3.033361*i[2]) + 1)
|
||||
h[1] = 1/(0.300038*exp(-4.802896*h[2] + 9.215506*i[0] + 8.091845*i[1] - 10.241107*i[2]) + 1)
|
||||
h[2] = 1/(0.269965*exp(4.279962*i[1]) + 1)
|
||||
o[0] = 1/(0.679321*exp(-20.860441*h[0] + 11.122242*h[1] + 14.216276*i[0] - 1.961642*i[2]) + 1)
|
||||
```
|
||||
|
||||
|
||||
## Installation
|
||||
Install `tensorneat` from the GitHub source code:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat import algorithm, genome, problem, common
|
||||
from tensorneat import algorithm, genome, problem
|
||||
from tensorneat.common import ACT
|
||||
|
||||
algorithm = algorithm.NEAT(
|
||||
pop_size=10000,
|
||||
@@ -8,7 +9,8 @@ algorithm = algorithm.NEAT(
|
||||
genome=genome.DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
output_transform=common.ACT.sigmoid,
|
||||
max_nodes=7,
|
||||
output_transform=ACT.sigmoid,
|
||||
),
|
||||
)
|
||||
problem = problem.XOR3d()
|
||||
@@ -25,3 +27,20 @@ state = pipeline.setup()
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
|
||||
# visualize the best individual
|
||||
network = algorithm.genome.network_dict(state, *best)
|
||||
algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
|
||||
|
||||
# transform the best individual to latex formula
|
||||
from tensorneat.common.sympy_tools import to_latex_code, to_python_code
|
||||
|
||||
sympy_res = algorithm.genome.sympy_func(
|
||||
state, network, sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid)
|
||||
)
|
||||
latex_code = to_latex_code(*sympy_res)
|
||||
print(latex_code)
|
||||
|
||||
# transform the best individual to python code
|
||||
python_code = to_python_code(*sympy_res)
|
||||
print(python_code)
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
|
Before Width: | Height: | Size: 89 KiB After Width: | Height: | Size: 18 KiB |
File diff suppressed because one or more lines are too long
@@ -1,13 +1,9 @@
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 创建一个空白的有向图
|
||||
G = nx.DiGraph()
|
||||
|
||||
# 添加边
|
||||
G.add_edge('A', 'B')
|
||||
G.add_edge('A', 'C')
|
||||
G.add_edge('B', 'C')
|
||||
G.add_edge('C', 'D')
|
||||
|
||||
# 绘制有向图
|
||||
h = np.zeros(3)
|
||||
o = np.zeros(4)
|
||||
h[0] = -tanh(0.540486*i[0] + 1.04397*i[1] + 0.58006*i[10] + 0.658223*i[11] - 0.9918*i[12] - 0.01919*i[13] + 0.194062*i[14] + 0.903314*i[15] - 1.906567*i[2] - 1.666336*i[3] + 0.653257*i[4] + 0.580191*i[5] + 0.177264*i[6] + 0.830688*i[7] - 0.855676*i[8] + 0.326538*i[9] + 2.465507)
|
||||
h[1] = -tanh(1.441044*i[0] - 0.606109*i[1] - 0.736058*i[10] + 0.60264*i[11] - 0.837565*i[12] + 2.018719*i[13] + 0.327097*i[14] + 0.098963*i[15] + 0.403485*i[2] - 0.680547*i[3] + 0.349021*i[4] - 1.359364*i[5] + 0.351466*i[6] + 0.450447*i[7] + 2.102749*i[8] + 0.680605*i[9] + 0.593945)
|
||||
h[2] = -tanh(1.350645*i[0] - 0.281682*i[1] + 0.332992*i[10] + 0.703457*i[11] + 1.290286*i[12] - 1.059887*i[13] - 1.114513*i[14] + 0.446127*i[15] + 1.103008*i[2] + 1.080698*i[3] - 0.89471*i[4] + 0.103146*i[5] - 0.828767*i[6] + 0.609362*i[7] - 0.765917*i[8] + 0.047898*i[9] + 0.649254)
|
||||
o[0] = -1.307307*h[0] - 0.985838*h[1] - 0.746408*h[2] + 0.245725
|
||||
o[1] = 0.64947*h[0] + 2.865669*h[1] + 1.185759*h[2] - 1.347174
|
||||
o[2] = 2.030407*h[0] + 1.001914*h[1] - 1.041287*h[2] + 0.301639
|
||||
o[3] = 0.717661*h[0] + 0.653905*h[1] - 1.387949*h[2] - 1.200779
|
||||
226
imgs/xor_network.svg
Normal file
226
imgs/xor_network.svg
Normal file
@@ -0,0 +1,226 @@
|
||||
<?xml version="1.0" encoding="utf-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
|
||||
<metadata>
|
||||
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
|
||||
<cc:Work>
|
||||
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
|
||||
<dc:date>2024-07-12T07:46:33.195425</dc:date>
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:creator>
|
||||
<cc:Agent>
|
||||
<dc:title>Matplotlib v3.9.0, https://matplotlib.org/</dc:title>
|
||||
</cc:Agent>
|
||||
</dc:creator>
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<defs>
|
||||
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
|
||||
</defs>
|
||||
<g id="figure_1">
|
||||
<g id="patch_1">
|
||||
<path d="M 0 345.6
|
||||
L 460.8 345.6
|
||||
L 460.8 0
|
||||
L 0 0
|
||||
z
|
||||
" style="fill: #ffffff"/>
|
||||
</g>
|
||||
<g id="axes_1">
|
||||
<g id="patch_2">
|
||||
<path d="M 48.095958 312.568974
|
||||
Q 230.401029 244.204573 411.659252 176.232739
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 410.324982 176.09229
|
||||
L 411.659252 176.232739
|
||||
L 410.746331 177.215885
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_3">
|
||||
<path d="M 46.916335 310.412749
|
||||
Q 135.193491 244.204882 222.57622 178.667835
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 221.25622 178.907835
|
||||
L 222.57622 178.667835
|
||||
L 221.97622 179.867835
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_4">
|
||||
<path d="M 48.647998 172.8
|
||||
Q 135.193467 172.8 220.620902 172.8
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 219.420902 172.2
|
||||
L 220.620902 172.8
|
||||
L 219.420902 173.4
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_5">
|
||||
<path d="M 48.647998 172.8
|
||||
Q 87.590519 172.8 125.415005 172.8
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 124.215005 172.2
|
||||
L 125.415005 172.8
|
||||
L 124.215005 173.4
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_6">
|
||||
<path d="M 46.916335 35.187251
|
||||
Q 135.193491 101.395118 222.57622 166.932165
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 221.97622 165.732165
|
||||
L 222.57622 166.932165
|
||||
L 221.25622 166.692165
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_7">
|
||||
<path d="M 48.095958 33.031026
|
||||
Q 230.401029 101.395427 411.659252 169.367261
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 410.746331 168.384115
|
||||
L 411.659252 169.367261
|
||||
L 410.324982 169.50771
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_8">
|
||||
<path d="M 47.731321 33.862355
|
||||
Q 182.795689 101.394539 316.860058 168.426723
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 316.055073 167.35341
|
||||
L 316.860058 168.426723
|
||||
L 315.518417 168.426723
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_9">
|
||||
<path d="M 143.85461 172.8
|
||||
Q 182.79713 172.8 220.621616 172.8
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 219.421616 172.2
|
||||
L 220.621616 172.8
|
||||
L 219.421616 173.4
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_10">
|
||||
<path d="M 143.85461 172.8
|
||||
Q 230.400079 172.8 315.827513 172.8
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 314.627513 172.2
|
||||
L 315.827513 172.8
|
||||
L 314.627513 173.4
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_11">
|
||||
<path d="M 239.061222 172.8
|
||||
Q 278.003742 172.8 315.828228 172.8
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 314.628228 172.2
|
||||
L 315.828228 172.8
|
||||
L 314.628228 173.4
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_12">
|
||||
<path d="M 239.061222 172.8
|
||||
Q 325.60669 172.8 411.034125 172.8
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 409.834125 172.2
|
||||
L 411.034125 172.8
|
||||
L 409.834125 173.4
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_13">
|
||||
<path d="M 334.267833 172.8
|
||||
Q 373.210353 172.8 411.03484 172.8
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 409.83484 172.2
|
||||
L 411.03484 172.8
|
||||
L 409.83484 173.4
|
||||
" clip-path="url(#pbc29be279d)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="PathCollection_1">
|
||||
<path d="M 39.986777 324.270171
|
||||
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
|
||||
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
|
||||
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
|
||||
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
|
||||
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
|
||||
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
|
||||
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
|
||||
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
|
||||
z
|
||||
" clip-path="url(#pbc29be279d)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 39.986777 181.460254
|
||||
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
|
||||
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
|
||||
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
|
||||
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
|
||||
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
|
||||
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
|
||||
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
|
||||
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
|
||||
z
|
||||
" clip-path="url(#pbc29be279d)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 39.986777 38.650337
|
||||
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
|
||||
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
|
||||
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
|
||||
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
|
||||
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
|
||||
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
|
||||
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
|
||||
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
|
||||
z
|
||||
" clip-path="url(#pbc29be279d)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 135.193388 181.460254
|
||||
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
|
||||
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
|
||||
C 143.853642 170.503274 142.941143 168.300306 141.317113 166.676276
|
||||
C 139.693082 165.052245 137.490115 164.139746 135.193388 164.139746
|
||||
C 132.896662 164.139746 130.693695 165.052245 129.069664 166.676276
|
||||
C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
|
||||
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
|
||||
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
|
||||
z
|
||||
" clip-path="url(#pbc29be279d)" style="fill: #ffffff; stroke: #000000"/>
|
||||
<path d="M 230.4 181.460254
|
||||
C 232.696726 181.460254 234.899694 180.547755 236.523724 178.923724
|
||||
C 238.147755 177.299694 239.060254 175.096726 239.060254 172.8
|
||||
C 239.060254 170.503274 238.147755 168.300306 236.523724 166.676276
|
||||
C 234.899694 165.052245 232.696726 164.139746 230.4 164.139746
|
||||
C 228.103274 164.139746 225.900306 165.052245 224.276276 166.676276
|
||||
C 222.652245 168.300306 221.739746 170.503274 221.739746 172.8
|
||||
C 221.739746 175.096726 222.652245 177.299694 224.276276 178.923724
|
||||
C 225.900306 180.547755 228.103274 181.460254 230.4 181.460254
|
||||
z
|
||||
" clip-path="url(#pbc29be279d)" style="fill: #ffffff; stroke: #000000"/>
|
||||
<path d="M 325.606612 181.460254
|
||||
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
|
||||
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
|
||||
C 334.266866 170.503274 333.354367 168.300306 331.730336 166.676276
|
||||
C 330.106305 165.052245 327.903338 164.139746 325.606612 164.139746
|
||||
C 323.309885 164.139746 321.106918 165.052245 319.482887 166.676276
|
||||
C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
|
||||
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
|
||||
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
|
||||
z
|
||||
" clip-path="url(#pbc29be279d)" style="fill: #ffffff; stroke: #000000"/>
|
||||
<path d="M 420.813223 181.460254
|
||||
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
|
||||
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
|
||||
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
|
||||
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
|
||||
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
|
||||
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
|
||||
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
|
||||
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
|
||||
z
|
||||
" clip-path="url(#pbc29be279d)" style="fill: #0000ff; stroke: #000000"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="pbc29be279d">
|
||||
<rect x="0" y="0" width="460.8" height="345.6"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 9.9 KiB |
91
src/tensorneat/common/sympy_tools.py
Normal file
91
src/tensorneat/common/sympy_tools.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import re
|
||||
import sympy as sp
|
||||
|
||||
def analysis_nodes_exprs(nodes_exprs):
|
||||
input_cnt, hidden_cnt, output_cnt = 0, 0, 0
|
||||
norm_symbols = {}
|
||||
for key in nodes_exprs.keys():
|
||||
if str(key).startswith('i'):
|
||||
input_cnt += 1
|
||||
elif str(key).startswith('h'):
|
||||
hidden_cnt += 1
|
||||
elif str(key).startswith('o'):
|
||||
output_cnt += 1
|
||||
elif str(key).startswith('norm'):
|
||||
norm_symbols[key] = nodes_exprs[key]
|
||||
return input_cnt, hidden_cnt, output_cnt, norm_symbols
|
||||
|
||||
def round_expr(expr, precision=2):
|
||||
"""
|
||||
Round numerical values in a sympy expression to a given precision.
|
||||
"""
|
||||
return expr.xreplace({n: round(n, precision) for n in expr.atoms(sp.Number)})
|
||||
|
||||
|
||||
def replace_variable_names(expression):
|
||||
"""
|
||||
Transform sympy expression to a string with array index that can be used in python code.
|
||||
For example, `o0` will be transformed to `o[0]`.
|
||||
"""
|
||||
expression_str = str(expression)
|
||||
expression_str = re.sub(r"\bo(\d+)\b", r"o[\1]", expression_str)
|
||||
expression_str = re.sub(r"\bh(\d+)\b", r"h[\1]", expression_str)
|
||||
expression_str = re.sub(r"\bi(\d+)\b", r"i[\1]", expression_str)
|
||||
return expression_str
|
||||
|
||||
|
||||
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True):
|
||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||
res = "\\begin{align}\n"
|
||||
|
||||
if not use_hidden_nodes:
|
||||
for i in range(output_cnt):
|
||||
expr = output_exprs[i].subs(args_symbols)
|
||||
rounded_expr = round_expr(expr, 2)
|
||||
latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
|
||||
res += latex_expr
|
||||
else:
|
||||
for i in range(hidden_cnt):
|
||||
symbol = sp.symbols(f"h{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 2)
|
||||
latex_expr = f"h_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
|
||||
res += latex_expr
|
||||
for i in range(output_cnt):
|
||||
symbol = sp.symbols(f"o{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 2)
|
||||
latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
|
||||
res += latex_expr
|
||||
res += "\\end{align}\n"
|
||||
return res
|
||||
|
||||
|
||||
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True):
|
||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||
res = ""
|
||||
if not use_hidden_nodes:
|
||||
# pre-allocate space
|
||||
res += f"o = np.zeros({output_cnt})\n"
|
||||
for i in range(output_cnt):
|
||||
expr = output_exprs[i].subs(args_symbols)
|
||||
rounded_expr = round_expr(expr, 6)
|
||||
str_expr = f"o{i} = {rounded_expr}"
|
||||
res += replace_variable_names(str_expr) + "\n"
|
||||
else:
|
||||
# pre-allocate space
|
||||
res += f"h = np.zeros({hidden_cnt})\n"
|
||||
res += f"o = np.zeros({output_cnt})\n"
|
||||
for i in range(hidden_cnt):
|
||||
symbol = sp.symbols(f"h{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 6)
|
||||
str_expr = f"h{i} = {rounded_expr}"
|
||||
res += replace_variable_names(str_expr) + "\n"
|
||||
for i in range(output_cnt):
|
||||
symbol = sp.symbols(f"o{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 6)
|
||||
str_expr = f"o{i} = {rounded_expr}"
|
||||
res += replace_variable_names(str_expr) + "\n"
|
||||
return res
|
||||
@@ -10,7 +10,7 @@ from tensorneat.common import (
|
||||
StatefulBaseClass,
|
||||
hash_array,
|
||||
)
|
||||
from .utils import valid_cnt
|
||||
from .utils import valid_cnt, re_cound_idx
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -160,7 +160,11 @@ class BaseGenome(StatefulBaseClass):
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def network_dict(self, state, nodes, conns):
|
||||
def network_dict(self, state, nodes, conns, whether_re_cound_idx=True):
|
||||
if whether_re_cound_idx:
|
||||
nodes, conns = re_cound_idx(
|
||||
nodes, conns, self.get_input_idx(), self.get_output_idx()
|
||||
)
|
||||
return {
|
||||
"nodes": self._get_node_dict(state, nodes),
|
||||
"conns": self._get_conn_dict(state, conns),
|
||||
|
||||
@@ -209,7 +209,6 @@ class DefaultNode(BaseNode):
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
print(nd["agg"])
|
||||
z = AGG.obtain_sympy(nd["agg"])(inputs)
|
||||
z = bias + res * z
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from tensorneat.common import fetch_first, I_INF
|
||||
|
||||
@@ -107,3 +108,33 @@ def delete_conn_by_pos(conns, pos):
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def re_cound_idx(nodes, conns, input_idx, output_idx):
|
||||
"""
|
||||
Make the key of hidden nodes continuous.
|
||||
Also update the index of connections.
|
||||
"""
|
||||
nodes, conns = jax.device_get((nodes, conns))
|
||||
next_key = max(*input_idx, *output_idx) + 1
|
||||
old2new = {}
|
||||
for i, key in enumerate(nodes[:, 0]):
|
||||
if np.isnan(key):
|
||||
continue
|
||||
if np.in1d(key, input_idx + output_idx):
|
||||
continue
|
||||
old2new[int(key)] = next_key
|
||||
next_key += 1
|
||||
|
||||
new_nodes = nodes.copy()
|
||||
for i, key in enumerate(nodes[:, 0]):
|
||||
if (not np.isnan(key)) and int(key) in old2new:
|
||||
new_nodes[i, 0] = old2new[int(key)]
|
||||
|
||||
new_conns = conns.copy()
|
||||
for i, (i_key, o_key) in enumerate(conns[:, :2]):
|
||||
if (not np.isnan(i_key)) and int(i_key) in old2new:
|
||||
new_conns[i, 0] = old2new[int(i_key)]
|
||||
if (not np.isnan(o_key)) and int(o_key) in old2new:
|
||||
new_conns[i, 1] = old2new[int(o_key)]
|
||||
return new_nodes, new_conns
|
||||
|
||||
Reference in New Issue
Block a user