update recurrent genome

This commit is contained in:
root
2024-07-10 16:27:49 +08:00
parent 1d606eb1c3
commit 649d4b0552
8 changed files with 490 additions and 46 deletions

View File

@@ -1,10 +1,21 @@
import jax, jax.numpy as jnp
from tensorneat.algorithm import NEAT
from tensorneat.algorithm.neat import DefaultGenome
from tensorneat.algorithm.neat import DefaultGenome, RecurrentGenome
key = jax.random.key(0)
genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=())
genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=(1, 2 ,3))
state = genome.setup()
nodes, conns = genome.initialize(state, key)
print(genome.repr(state, nodes, conns))
inputs = jnp.array([1, 2, 3, 4, 5])
transformed = genome.transform(state, nodes, conns)
outputs = genome.forward(state, transformed, inputs)
print(outputs)
network = genome.network_dict(state, nodes, conns)
print(network)
genome.visualize(network)

16
examples/tmp2.py Normal file
View File

@@ -0,0 +1,16 @@
import jax, jax.numpy as jnp
arr = jnp.ones((10, 10))
a = jnp.array([
[1, 2, 3],
[4, 5, 6]
])
def attach_with_inf(arr, idx):
target_dim = arr.ndim + idx.ndim - 1
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
return jnp.where(expand_idx == 1, jnp.nan, arr[idx])
b = attach_with_inf(arr, a)
print(b)

415
network.svg Normal file
View File

@@ -0,0 +1,415 @@
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
<metadata>
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:date>2024-07-10T15:27:16.806503</dc:date>
<dc:format>image/svg+xml</dc:format>
<dc:creator>
<cc:Agent>
<dc:title>Matplotlib v3.9.0, https://matplotlib.org/</dc:title>
</cc:Agent>
</dc:creator>
</cc:Work>
</rdf:RDF>
</metadata>
<defs>
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
</defs>
<g id="figure_1">
<g id="patch_1">
<path d="M 0 345.6
L 460.8 345.6
L 460.8 0
L 0 0
z
" style="fill: #ffffff"/>
</g>
<g id="axes_1">
<g id="patch_2">
<path d="M 44.79098 308.403612
Q 87.590594 244.204191 129.770035 180.93503
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 125.887134 183.153831
L 129.770035 180.93503
L 129.215335 185.372632
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_3">
<path d="M 46.916335 239.00779
Q 87.591722 208.50125 127.372682 178.665529
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 122.972682 179.465529
L 127.372682 178.665529
L 125.372682 182.665529
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_4">
<path d="M 48.647998 172.8
Q 87.590519 172.8 125.415005 172.8
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 121.415005 170.8
L 125.415005 172.8
L 121.415005 174.8
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_5">
<path d="M 46.916335 106.59221
Q 87.591722 137.09875 127.372682 166.934471
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 125.372682 162.934471
L 127.372682 166.934471
L 122.972682 166.134471
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_6">
<path d="M 44.79098 37.196388
Q 87.590594 101.395809 129.770035 164.66497
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 129.215335 160.227368
L 129.770035 164.66497
L 125.887134 162.446169
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_7">
<path d="M 143.30257 175.840943
Q 182.796502 190.651168 221.243586 205.068824
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 218.200516 201.791672
L 221.243586 205.068824
L 216.796023 205.536989
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_8">
<path d="M 143.30257 169.759057
Q 182.796502 154.948832 221.243586 140.531176
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 216.796023 140.063011
L 221.243586 140.531176
L 218.200516 143.808328
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_9">
<path d="M 238.509181 211.543422
Q 278.003113 226.353647 316.450198 240.771303
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 313.407128 237.494151
L 316.450198 240.771303
L 312.002634 241.239468
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_10">
<path d="M 238.509181 205.461536
Q 278.003113 190.651312 316.450198 176.233655
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 312.002634 175.765491
L 316.450198 176.233655
L 313.407128 179.510807
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_11">
<path d="M 236.155746 202.027265
Q 278.00531 154.946506 319.112092 108.701376
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 314.959818 110.362286
L 319.112092 108.701376
L 317.949455 113.019741
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_12">
<path d="M 236.155746 143.572735
Q 278.00531 190.653494 319.112092 236.898624
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 317.949455 232.580259
L 319.112092 236.898624
L 314.959818 235.237714
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_13">
<path d="M 238.509181 140.138464
Q 278.003113 154.948688 316.450198 169.366345
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 313.407128 166.089193
L 316.450198 169.366345
L 312.002634 169.834509
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_14">
<path d="M 238.509181 134.056578
Q 278.003113 119.246353 316.450198 104.828697
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 312.002634 104.360532
L 316.450198 104.828697
L 313.407128 108.105849
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_15">
<path d="M 334.267833 244.204959
Q 373.210353 244.204959 411.03484 244.204959
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 242.204959
L 411.03484 244.204959
L 407.03484 246.204959
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_16">
<path d="M 332.53617 239.00779
Q 373.211557 208.50125 412.992517 178.665529
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 408.592517 179.465529
L 412.992517 178.665529
L 410.992517 182.665529
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_17">
<path d="M 330.410815 236.998654
Q 373.210429 172.799232 415.38987 109.530071
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 411.506968 111.748872
L 415.38987 109.530071
L 414.83517 113.967673
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_18">
<path d="M 332.53617 177.997169
Q 373.211557 208.503709 412.992517 238.339429
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 410.992517 234.339429
L 412.992517 238.339429
L 408.592517 237.539429
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_19">
<path d="M 334.267833 172.8
Q 373.210353 172.8 411.03484 172.8
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 170.8
L 411.03484 172.8
L 407.03484 174.8
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_20">
<path d="M 332.53617 167.602831
Q 373.211557 137.096291 412.992517 107.260571
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 408.592517 108.060571
L 412.992517 107.260571
L 410.992517 111.260571
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_21">
<path d="M 330.410815 108.601346
Q 373.210429 172.800768 415.38987 236.069929
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 414.83517 231.632327
L 415.38987 236.069929
L 411.506968 233.851128
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_22">
<path d="M 332.53617 106.59221
Q 373.211557 137.09875 412.992517 166.934471
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 410.992517 162.934471
L 412.992517 166.934471
L 408.592517 166.134471
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_23">
<path d="M 334.267833 101.395041
Q 373.210353 101.395041 411.03484 101.395041
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 99.395041
L 411.03484 101.395041
L 407.03484 103.395041
z
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="PathCollection_1">
<path d="M 39.986777 324.270171
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 252.865213
C 42.283503 252.865213 44.486471 251.952714 46.110501 250.328683
C 47.734532 248.704652 48.647031 246.501685 48.647031 244.204959
C 48.647031 241.908232 47.734532 239.705265 46.110501 238.081234
C 44.486471 236.457204 42.283503 235.544705 39.986777 235.544705
C 37.690051 235.544705 35.487083 236.457204 33.863053 238.081234
C 32.239022 239.705265 31.326523 241.908232 31.326523 244.204959
C 31.326523 246.501685 32.239022 248.704652 33.863053 250.328683
C 35.487083 251.952714 37.690051 252.865213 39.986777 252.865213
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 181.460254
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 110.055295
C 42.283503 110.055295 44.486471 109.142796 46.110501 107.518766
C 47.734532 105.894735 48.647031 103.691768 48.647031 101.395041
C 48.647031 99.098315 47.734532 96.895348 46.110501 95.271317
C 44.486471 93.647286 42.283503 92.734787 39.986777 92.734787
C 37.690051 92.734787 35.487083 93.647286 33.863053 95.271317
C 32.239022 96.895348 31.326523 99.098315 31.326523 101.395041
C 31.326523 103.691768 32.239022 105.894735 33.863053 107.518766
C 35.487083 109.142796 37.690051 110.055295 39.986777 110.055295
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 38.650337
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 135.193388 181.460254
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
C 143.853642 170.503274 142.941143 168.300306 141.317113 166.676276
C 139.693082 165.052245 137.490115 164.139746 135.193388 164.139746
C 132.896662 164.139746 130.693695 165.052245 129.069664 166.676276
C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 230.4 217.162733
C 232.696726 217.162733 234.899694 216.250234 236.523724 214.626204
C 238.147755 213.002173 239.060254 210.799206 239.060254 208.502479
C 239.060254 206.205753 238.147755 204.002786 236.523724 202.378755
C 234.899694 200.754724 232.696726 199.842225 230.4 199.842225
C 228.103274 199.842225 225.900306 200.754724 224.276276 202.378755
C 222.652245 204.002786 221.739746 206.205753 221.739746 208.502479
C 221.739746 210.799206 222.652245 213.002173 224.276276 214.626204
C 225.900306 216.250234 228.103274 217.162733 230.4 217.162733
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 230.4 145.757775
C 232.696726 145.757775 234.899694 144.845276 236.523724 143.221245
C 238.147755 141.597214 239.060254 139.394247 239.060254 137.097521
C 239.060254 134.800794 238.147755 132.597827 236.523724 130.973796
C 234.899694 129.349766 232.696726 128.437267 230.4 128.437267
C 228.103274 128.437267 225.900306 129.349766 224.276276 130.973796
C 222.652245 132.597827 221.739746 134.800794 221.739746 137.097521
C 221.739746 139.394247 222.652245 141.597214 224.276276 143.221245
C 225.900306 144.845276 228.103274 145.757775 230.4 145.757775
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 252.865213
C 327.903338 252.865213 330.106305 251.952714 331.730336 250.328683
C 333.354367 248.704652 334.266866 246.501685 334.266866 244.204959
C 334.266866 241.908232 333.354367 239.705265 331.730336 238.081234
C 330.106305 236.457204 327.903338 235.544705 325.606612 235.544705
C 323.309885 235.544705 321.106918 236.457204 319.482887 238.081234
C 317.858857 239.705265 316.946358 241.908232 316.946358 244.204959
C 316.946358 246.501685 317.858857 248.704652 319.482887 250.328683
C 321.106918 251.952714 323.309885 252.865213 325.606612 252.865213
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 181.460254
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
C 334.266866 170.503274 333.354367 168.300306 331.730336 166.676276
C 330.106305 165.052245 327.903338 164.139746 325.606612 164.139746
C 323.309885 164.139746 321.106918 165.052245 319.482887 166.676276
C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 110.055295
C 327.903338 110.055295 330.106305 109.142796 331.730336 107.518766
C 333.354367 105.894735 334.266866 103.691768 334.266866 101.395041
C 334.266866 99.098315 333.354367 96.895348 331.730336 95.271317
C 330.106305 93.647286 327.903338 92.734787 325.606612 92.734787
C 323.309885 92.734787 321.106918 93.647286 319.482887 95.271317
C 317.858857 96.895348 316.946358 99.098315 316.946358 101.395041
C 316.946358 103.691768 317.858857 105.894735 319.482887 107.518766
C 321.106918 109.142796 323.309885 110.055295 325.606612 110.055295
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 252.865213
C 423.109949 252.865213 425.312917 251.952714 426.936947 250.328683
C 428.560978 248.704652 429.473477 246.501685 429.473477 244.204959
C 429.473477 241.908232 428.560978 239.705265 426.936947 238.081234
C 425.312917 236.457204 423.109949 235.544705 420.813223 235.544705
C 418.516497 235.544705 416.313529 236.457204 414.689499 238.081234
C 413.065468 239.705265 412.152969 241.908232 412.152969 244.204959
C 412.152969 246.501685 413.065468 248.704652 414.689499 250.328683
C 416.313529 251.952714 418.516497 252.865213 420.813223 252.865213
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 181.460254
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 110.055295
C 423.109949 110.055295 425.312917 109.142796 426.936947 107.518766
C 428.560978 105.894735 429.473477 103.691768 429.473477 101.395041
C 429.473477 99.098315 428.560978 96.895348 426.936947 95.271317
C 425.312917 93.647286 423.109949 92.734787 420.813223 92.734787
C 418.516497 92.734787 416.313529 93.647286 414.689499 95.271317
C 413.065468 96.895348 412.152969 99.098315 412.152969 101.395041
C 412.152969 103.691768 413.065468 105.894735 414.689499 107.518766
C 416.313529 109.142796 418.516497 110.055295 420.813223 110.055295
z
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
</g>
</g>
</g>
<defs>
<clipPath id="p80fa8c6777">
<rect x="0" y="0" width="460.8" height="345.6"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 18 KiB

View File

@@ -34,9 +34,6 @@ class BaseGene(StatefulBaseClass):
def forward(self, state, attrs, inputs):
raise NotImplementedError
def update_by_batch(self, state, attrs, batch_inputs):
raise NotImplementedError
@property
def length(self):
return len(self.fixed_attrs) + len(self.custom_attrs)

View File

@@ -64,7 +64,7 @@ class BaseGenome(StatefulBaseClass):
all_init_conns_in_idx.append(in_idx)
all_init_conns_out_idx.append(out_idx)
all_init_nodes.extend(in_layer)
all_init_nodes.extend(layer_indices[-1])
all_init_nodes.extend(layer_indices[-1]) # output layer
if max_nodes < len(all_init_nodes):
raise ValueError(

View File

@@ -78,31 +78,34 @@ class DefaultGenome(BaseGenome):
def cond_fun(carry):
values, idx = carry
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
return (idx < self.max_nodes) & (
cal_seqs[idx] != I_INF
) # not out of bounds and next node exists
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def input_node():
z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
new_values = values.at[i].set(z)
return new_values
return values
def otherwise():
# calculate connections
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values
)
# calculate nodes
z = self.node_gene.forward(
state,
nodes_attrs[i],
ins,
is_output_node=jnp.isin(i, self.output_idx),
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
)
# set new value
new_values = values.at[i].set(z)
return new_values

View File

@@ -1,12 +1,13 @@
from typing import Callable
import jax, jax.numpy as jnp
import jax
from jax import vmap, numpy as jnp
from .utils import unflatten_conns
from . import BaseGenome
from .base import BaseGenome
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from ..gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover
from tensorneat.common import attach_with_inf
class RecurrentGenome(BaseGenome):
"""Default genome class, with the same behavior as the NEAT-Python"""
@@ -17,14 +18,17 @@ class RecurrentGenome(BaseGenome):
self,
num_inputs: int,
num_outputs: int,
max_nodes = 50,
max_conns = 100,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(),
conn_gene=DefaultConnGene(),
mutation=DefaultMutation(),
crossover=DefaultCrossover(),
distance=DefaultDistance(),
output_transform=None,
input_transform=None,
init_hidden_layers=(),
activate_time=10,
output_transform: Callable = None,
):
super().__init__(
num_inputs,
@@ -35,29 +39,25 @@ class RecurrentGenome(BaseGenome):
conn_gene,
mutation,
crossover,
distance,
output_transform,
input_transform,
init_hidden_layers,
)
self.activate_time = activate_time
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, state, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
return nodes, conns, u_conns
def restore(self, state, transformed):
def forward(self, state, transformed, inputs):
nodes, conns, u_conns = transformed
return nodes, conns
def forward(self, state, inputs, transformed):
nodes, conns = transformed
vals = jnp.full((self.max_nodes,), jnp.nan)
nodes_attrs = nodes[:, 1:] # remove index
nodes_attrs = vmap(extract_node_attrs)(nodes)
conns_attrs = vmap(extract_conn_attrs)(conns)
expand_conns_attrs = attach_with_inf(conns_attrs, u_conns)
def body_func(_, values):
@@ -65,14 +65,14 @@ class RecurrentGenome(BaseGenome):
values = values.at[self.input_idx].set(inputs)
# calculate connections
node_ins = jax.vmap(
jax.vmap(self.conn_gene.forward, in_axes=(None, 1, None)),
in_axes=(None, 1, 0),
)(state, conns, values)
node_ins = vmap(
vmap(self.conn_gene.forward, in_axes=(None, 0, None)),
in_axes=(None, 0, 0),
)(state, expand_conns_attrs, values)
# calculate nodes
is_output_nodes = jnp.isin(jnp.arange(self.max_nodes), self.output_idx)
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
is_output_nodes = jnp.isin(nodes[:, 0], self.output_idx)
values = vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
state, nodes_attrs, node_ins.T, is_output_nodes
)
@@ -87,3 +87,6 @@ class RecurrentGenome(BaseGenome):
def sympy_func(self, state, network, precision=3):
raise ValueError("Sympy function is not supported for Recurrent Network!")
def visualize(self, network):
raise ValueError("Visualize function is not supported for Recurrent Network!")

View File

@@ -6,12 +6,11 @@ from jax import numpy as jnp, Array, jit, vmap
I_INF = np.iinfo(jnp.int32).max # infinite int
# TODO: strange implementation
def attach_with_inf(arr, idx):
expand_size = arr.ndim - idx.ndim
expand_idx = jnp.expand_dims(
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
)
target_dim = arr.ndim + idx.ndim - 1
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])