update recurrent genome
This commit is contained in:
@@ -1,10 +1,21 @@
|
|||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
from tensorneat.algorithm import NEAT
|
from tensorneat.algorithm import NEAT
|
||||||
from tensorneat.algorithm.neat import DefaultGenome
|
from tensorneat.algorithm.neat import DefaultGenome, RecurrentGenome
|
||||||
|
|
||||||
key = jax.random.key(0)
|
key = jax.random.key(0)
|
||||||
genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=())
|
genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=(1, 2 ,3))
|
||||||
state = genome.setup()
|
state = genome.setup()
|
||||||
nodes, conns = genome.initialize(state, key)
|
nodes, conns = genome.initialize(state, key)
|
||||||
print(genome.repr(state, nodes, conns))
|
print(genome.repr(state, nodes, conns))
|
||||||
|
|
||||||
|
inputs = jnp.array([1, 2, 3, 4, 5])
|
||||||
|
transformed = genome.transform(state, nodes, conns)
|
||||||
|
outputs = genome.forward(state, transformed, inputs)
|
||||||
|
|
||||||
|
print(outputs)
|
||||||
|
|
||||||
|
network = genome.network_dict(state, nodes, conns)
|
||||||
|
print(network)
|
||||||
|
|
||||||
|
genome.visualize(network)
|
||||||
|
|||||||
16
examples/tmp2.py
Normal file
16
examples/tmp2.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
|
arr = jnp.ones((10, 10))
|
||||||
|
a = jnp.array([
|
||||||
|
[1, 2, 3],
|
||||||
|
[4, 5, 6]
|
||||||
|
])
|
||||||
|
|
||||||
|
def attach_with_inf(arr, idx):
|
||||||
|
target_dim = arr.ndim + idx.ndim - 1
|
||||||
|
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
|
||||||
|
|
||||||
|
return jnp.where(expand_idx == 1, jnp.nan, arr[idx])
|
||||||
|
|
||||||
|
b = attach_with_inf(arr, a)
|
||||||
|
print(b)
|
||||||
415
network.svg
Normal file
415
network.svg
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8" standalone="no"?>
|
||||||
|
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||||
|
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||||
|
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
|
||||||
|
<metadata>
|
||||||
|
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
|
||||||
|
<cc:Work>
|
||||||
|
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
|
||||||
|
<dc:date>2024-07-10T15:27:16.806503</dc:date>
|
||||||
|
<dc:format>image/svg+xml</dc:format>
|
||||||
|
<dc:creator>
|
||||||
|
<cc:Agent>
|
||||||
|
<dc:title>Matplotlib v3.9.0, https://matplotlib.org/</dc:title>
|
||||||
|
</cc:Agent>
|
||||||
|
</dc:creator>
|
||||||
|
</cc:Work>
|
||||||
|
</rdf:RDF>
|
||||||
|
</metadata>
|
||||||
|
<defs>
|
||||||
|
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
|
||||||
|
</defs>
|
||||||
|
<g id="figure_1">
|
||||||
|
<g id="patch_1">
|
||||||
|
<path d="M 0 345.6
|
||||||
|
L 460.8 345.6
|
||||||
|
L 460.8 0
|
||||||
|
L 0 0
|
||||||
|
z
|
||||||
|
" style="fill: #ffffff"/>
|
||||||
|
</g>
|
||||||
|
<g id="axes_1">
|
||||||
|
<g id="patch_2">
|
||||||
|
<path d="M 44.79098 308.403612
|
||||||
|
Q 87.590594 244.204191 129.770035 180.93503
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 125.887134 183.153831
|
||||||
|
L 129.770035 180.93503
|
||||||
|
L 129.215335 185.372632
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_3">
|
||||||
|
<path d="M 46.916335 239.00779
|
||||||
|
Q 87.591722 208.50125 127.372682 178.665529
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 122.972682 179.465529
|
||||||
|
L 127.372682 178.665529
|
||||||
|
L 125.372682 182.665529
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_4">
|
||||||
|
<path d="M 48.647998 172.8
|
||||||
|
Q 87.590519 172.8 125.415005 172.8
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 121.415005 170.8
|
||||||
|
L 125.415005 172.8
|
||||||
|
L 121.415005 174.8
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_5">
|
||||||
|
<path d="M 46.916335 106.59221
|
||||||
|
Q 87.591722 137.09875 127.372682 166.934471
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 125.372682 162.934471
|
||||||
|
L 127.372682 166.934471
|
||||||
|
L 122.972682 166.134471
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_6">
|
||||||
|
<path d="M 44.79098 37.196388
|
||||||
|
Q 87.590594 101.395809 129.770035 164.66497
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 129.215335 160.227368
|
||||||
|
L 129.770035 164.66497
|
||||||
|
L 125.887134 162.446169
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_7">
|
||||||
|
<path d="M 143.30257 175.840943
|
||||||
|
Q 182.796502 190.651168 221.243586 205.068824
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 218.200516 201.791672
|
||||||
|
L 221.243586 205.068824
|
||||||
|
L 216.796023 205.536989
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_8">
|
||||||
|
<path d="M 143.30257 169.759057
|
||||||
|
Q 182.796502 154.948832 221.243586 140.531176
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 216.796023 140.063011
|
||||||
|
L 221.243586 140.531176
|
||||||
|
L 218.200516 143.808328
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_9">
|
||||||
|
<path d="M 238.509181 211.543422
|
||||||
|
Q 278.003113 226.353647 316.450198 240.771303
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 313.407128 237.494151
|
||||||
|
L 316.450198 240.771303
|
||||||
|
L 312.002634 241.239468
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_10">
|
||||||
|
<path d="M 238.509181 205.461536
|
||||||
|
Q 278.003113 190.651312 316.450198 176.233655
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 312.002634 175.765491
|
||||||
|
L 316.450198 176.233655
|
||||||
|
L 313.407128 179.510807
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_11">
|
||||||
|
<path d="M 236.155746 202.027265
|
||||||
|
Q 278.00531 154.946506 319.112092 108.701376
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 314.959818 110.362286
|
||||||
|
L 319.112092 108.701376
|
||||||
|
L 317.949455 113.019741
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_12">
|
||||||
|
<path d="M 236.155746 143.572735
|
||||||
|
Q 278.00531 190.653494 319.112092 236.898624
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 317.949455 232.580259
|
||||||
|
L 319.112092 236.898624
|
||||||
|
L 314.959818 235.237714
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_13">
|
||||||
|
<path d="M 238.509181 140.138464
|
||||||
|
Q 278.003113 154.948688 316.450198 169.366345
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 313.407128 166.089193
|
||||||
|
L 316.450198 169.366345
|
||||||
|
L 312.002634 169.834509
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_14">
|
||||||
|
<path d="M 238.509181 134.056578
|
||||||
|
Q 278.003113 119.246353 316.450198 104.828697
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 312.002634 104.360532
|
||||||
|
L 316.450198 104.828697
|
||||||
|
L 313.407128 108.105849
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_15">
|
||||||
|
<path d="M 334.267833 244.204959
|
||||||
|
Q 373.210353 244.204959 411.03484 244.204959
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 407.03484 242.204959
|
||||||
|
L 411.03484 244.204959
|
||||||
|
L 407.03484 246.204959
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_16">
|
||||||
|
<path d="M 332.53617 239.00779
|
||||||
|
Q 373.211557 208.50125 412.992517 178.665529
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 408.592517 179.465529
|
||||||
|
L 412.992517 178.665529
|
||||||
|
L 410.992517 182.665529
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_17">
|
||||||
|
<path d="M 330.410815 236.998654
|
||||||
|
Q 373.210429 172.799232 415.38987 109.530071
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 411.506968 111.748872
|
||||||
|
L 415.38987 109.530071
|
||||||
|
L 414.83517 113.967673
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_18">
|
||||||
|
<path d="M 332.53617 177.997169
|
||||||
|
Q 373.211557 208.503709 412.992517 238.339429
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 410.992517 234.339429
|
||||||
|
L 412.992517 238.339429
|
||||||
|
L 408.592517 237.539429
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_19">
|
||||||
|
<path d="M 334.267833 172.8
|
||||||
|
Q 373.210353 172.8 411.03484 172.8
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 407.03484 170.8
|
||||||
|
L 411.03484 172.8
|
||||||
|
L 407.03484 174.8
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_20">
|
||||||
|
<path d="M 332.53617 167.602831
|
||||||
|
Q 373.211557 137.096291 412.992517 107.260571
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 408.592517 108.060571
|
||||||
|
L 412.992517 107.260571
|
||||||
|
L 410.992517 111.260571
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_21">
|
||||||
|
<path d="M 330.410815 108.601346
|
||||||
|
Q 373.210429 172.800768 415.38987 236.069929
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 414.83517 231.632327
|
||||||
|
L 415.38987 236.069929
|
||||||
|
L 411.506968 233.851128
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_22">
|
||||||
|
<path d="M 332.53617 106.59221
|
||||||
|
Q 373.211557 137.09875 412.992517 166.934471
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 410.992517 162.934471
|
||||||
|
L 412.992517 166.934471
|
||||||
|
L 408.592517 166.134471
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="patch_23">
|
||||||
|
<path d="M 334.267833 101.395041
|
||||||
|
Q 373.210353 101.395041 411.03484 101.395041
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||||
|
<path d="M 407.03484 99.395041
|
||||||
|
L 411.03484 101.395041
|
||||||
|
L 407.03484 103.395041
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="stroke: #000000; stroke-linecap: round"/>
|
||||||
|
</g>
|
||||||
|
<g id="PathCollection_1">
|
||||||
|
<path d="M 39.986777 324.270171
|
||||||
|
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
|
||||||
|
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
|
||||||
|
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
|
||||||
|
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
|
||||||
|
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
|
||||||
|
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
|
||||||
|
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
|
||||||
|
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 39.986777 252.865213
|
||||||
|
C 42.283503 252.865213 44.486471 251.952714 46.110501 250.328683
|
||||||
|
C 47.734532 248.704652 48.647031 246.501685 48.647031 244.204959
|
||||||
|
C 48.647031 241.908232 47.734532 239.705265 46.110501 238.081234
|
||||||
|
C 44.486471 236.457204 42.283503 235.544705 39.986777 235.544705
|
||||||
|
C 37.690051 235.544705 35.487083 236.457204 33.863053 238.081234
|
||||||
|
C 32.239022 239.705265 31.326523 241.908232 31.326523 244.204959
|
||||||
|
C 31.326523 246.501685 32.239022 248.704652 33.863053 250.328683
|
||||||
|
C 35.487083 251.952714 37.690051 252.865213 39.986777 252.865213
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 39.986777 181.460254
|
||||||
|
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
|
||||||
|
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
|
||||||
|
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
|
||||||
|
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
|
||||||
|
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
|
||||||
|
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
|
||||||
|
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
|
||||||
|
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 39.986777 110.055295
|
||||||
|
C 42.283503 110.055295 44.486471 109.142796 46.110501 107.518766
|
||||||
|
C 47.734532 105.894735 48.647031 103.691768 48.647031 101.395041
|
||||||
|
C 48.647031 99.098315 47.734532 96.895348 46.110501 95.271317
|
||||||
|
C 44.486471 93.647286 42.283503 92.734787 39.986777 92.734787
|
||||||
|
C 37.690051 92.734787 35.487083 93.647286 33.863053 95.271317
|
||||||
|
C 32.239022 96.895348 31.326523 99.098315 31.326523 101.395041
|
||||||
|
C 31.326523 103.691768 32.239022 105.894735 33.863053 107.518766
|
||||||
|
C 35.487083 109.142796 37.690051 110.055295 39.986777 110.055295
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 39.986777 38.650337
|
||||||
|
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
|
||||||
|
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
|
||||||
|
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
|
||||||
|
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
|
||||||
|
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
|
||||||
|
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
|
||||||
|
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
|
||||||
|
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 135.193388 181.460254
|
||||||
|
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
|
||||||
|
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
|
||||||
|
C 143.853642 170.503274 142.941143 168.300306 141.317113 166.676276
|
||||||
|
C 139.693082 165.052245 137.490115 164.139746 135.193388 164.139746
|
||||||
|
C 132.896662 164.139746 130.693695 165.052245 129.069664 166.676276
|
||||||
|
C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
|
||||||
|
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
|
||||||
|
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 230.4 217.162733
|
||||||
|
C 232.696726 217.162733 234.899694 216.250234 236.523724 214.626204
|
||||||
|
C 238.147755 213.002173 239.060254 210.799206 239.060254 208.502479
|
||||||
|
C 239.060254 206.205753 238.147755 204.002786 236.523724 202.378755
|
||||||
|
C 234.899694 200.754724 232.696726 199.842225 230.4 199.842225
|
||||||
|
C 228.103274 199.842225 225.900306 200.754724 224.276276 202.378755
|
||||||
|
C 222.652245 204.002786 221.739746 206.205753 221.739746 208.502479
|
||||||
|
C 221.739746 210.799206 222.652245 213.002173 224.276276 214.626204
|
||||||
|
C 225.900306 216.250234 228.103274 217.162733 230.4 217.162733
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 230.4 145.757775
|
||||||
|
C 232.696726 145.757775 234.899694 144.845276 236.523724 143.221245
|
||||||
|
C 238.147755 141.597214 239.060254 139.394247 239.060254 137.097521
|
||||||
|
C 239.060254 134.800794 238.147755 132.597827 236.523724 130.973796
|
||||||
|
C 234.899694 129.349766 232.696726 128.437267 230.4 128.437267
|
||||||
|
C 228.103274 128.437267 225.900306 129.349766 224.276276 130.973796
|
||||||
|
C 222.652245 132.597827 221.739746 134.800794 221.739746 137.097521
|
||||||
|
C 221.739746 139.394247 222.652245 141.597214 224.276276 143.221245
|
||||||
|
C 225.900306 144.845276 228.103274 145.757775 230.4 145.757775
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 325.606612 252.865213
|
||||||
|
C 327.903338 252.865213 330.106305 251.952714 331.730336 250.328683
|
||||||
|
C 333.354367 248.704652 334.266866 246.501685 334.266866 244.204959
|
||||||
|
C 334.266866 241.908232 333.354367 239.705265 331.730336 238.081234
|
||||||
|
C 330.106305 236.457204 327.903338 235.544705 325.606612 235.544705
|
||||||
|
C 323.309885 235.544705 321.106918 236.457204 319.482887 238.081234
|
||||||
|
C 317.858857 239.705265 316.946358 241.908232 316.946358 244.204959
|
||||||
|
C 316.946358 246.501685 317.858857 248.704652 319.482887 250.328683
|
||||||
|
C 321.106918 251.952714 323.309885 252.865213 325.606612 252.865213
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 325.606612 181.460254
|
||||||
|
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
|
||||||
|
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
|
||||||
|
C 334.266866 170.503274 333.354367 168.300306 331.730336 166.676276
|
||||||
|
C 330.106305 165.052245 327.903338 164.139746 325.606612 164.139746
|
||||||
|
C 323.309885 164.139746 321.106918 165.052245 319.482887 166.676276
|
||||||
|
C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
|
||||||
|
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
|
||||||
|
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 325.606612 110.055295
|
||||||
|
C 327.903338 110.055295 330.106305 109.142796 331.730336 107.518766
|
||||||
|
C 333.354367 105.894735 334.266866 103.691768 334.266866 101.395041
|
||||||
|
C 334.266866 99.098315 333.354367 96.895348 331.730336 95.271317
|
||||||
|
C 330.106305 93.647286 327.903338 92.734787 325.606612 92.734787
|
||||||
|
C 323.309885 92.734787 321.106918 93.647286 319.482887 95.271317
|
||||||
|
C 317.858857 96.895348 316.946358 99.098315 316.946358 101.395041
|
||||||
|
C 316.946358 103.691768 317.858857 105.894735 319.482887 107.518766
|
||||||
|
C 321.106918 109.142796 323.309885 110.055295 325.606612 110.055295
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 420.813223 252.865213
|
||||||
|
C 423.109949 252.865213 425.312917 251.952714 426.936947 250.328683
|
||||||
|
C 428.560978 248.704652 429.473477 246.501685 429.473477 244.204959
|
||||||
|
C 429.473477 241.908232 428.560978 239.705265 426.936947 238.081234
|
||||||
|
C 425.312917 236.457204 423.109949 235.544705 420.813223 235.544705
|
||||||
|
C 418.516497 235.544705 416.313529 236.457204 414.689499 238.081234
|
||||||
|
C 413.065468 239.705265 412.152969 241.908232 412.152969 244.204959
|
||||||
|
C 412.152969 246.501685 413.065468 248.704652 414.689499 250.328683
|
||||||
|
C 416.313529 251.952714 418.516497 252.865213 420.813223 252.865213
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 420.813223 181.460254
|
||||||
|
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
|
||||||
|
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
|
||||||
|
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
|
||||||
|
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
|
||||||
|
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
|
||||||
|
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
|
||||||
|
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
|
||||||
|
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
<path d="M 420.813223 110.055295
|
||||||
|
C 423.109949 110.055295 425.312917 109.142796 426.936947 107.518766
|
||||||
|
C 428.560978 105.894735 429.473477 103.691768 429.473477 101.395041
|
||||||
|
C 429.473477 99.098315 428.560978 96.895348 426.936947 95.271317
|
||||||
|
C 425.312917 93.647286 423.109949 92.734787 420.813223 92.734787
|
||||||
|
C 418.516497 92.734787 416.313529 93.647286 414.689499 95.271317
|
||||||
|
C 413.065468 96.895348 412.152969 99.098315 412.152969 101.395041
|
||||||
|
C 412.152969 103.691768 413.065468 105.894735 414.689499 107.518766
|
||||||
|
C 416.313529 109.142796 418.516497 110.055295 420.813223 110.055295
|
||||||
|
z
|
||||||
|
" clip-path="url(#p80fa8c6777)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<defs>
|
||||||
|
<clipPath id="p80fa8c6777">
|
||||||
|
<rect x="0" y="0" width="460.8" height="345.6"/>
|
||||||
|
</clipPath>
|
||||||
|
</defs>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 18 KiB |
@@ -34,9 +34,6 @@ class BaseGene(StatefulBaseClass):
|
|||||||
def forward(self, state, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def update_by_batch(self, state, attrs, batch_inputs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def length(self):
|
def length(self):
|
||||||
return len(self.fixed_attrs) + len(self.custom_attrs)
|
return len(self.fixed_attrs) + len(self.custom_attrs)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ class BaseGenome(StatefulBaseClass):
|
|||||||
all_init_conns_in_idx.append(in_idx)
|
all_init_conns_in_idx.append(in_idx)
|
||||||
all_init_conns_out_idx.append(out_idx)
|
all_init_conns_out_idx.append(out_idx)
|
||||||
all_init_nodes.extend(in_layer)
|
all_init_nodes.extend(in_layer)
|
||||||
all_init_nodes.extend(layer_indices[-1])
|
all_init_nodes.extend(layer_indices[-1]) # output layer
|
||||||
|
|
||||||
if max_nodes < len(all_init_nodes):
|
if max_nodes < len(all_init_nodes):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -78,31 +78,34 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
def cond_fun(carry):
|
def cond_fun(carry):
|
||||||
values, idx = carry
|
values, idx = carry
|
||||||
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
|
return (idx < self.max_nodes) & (
|
||||||
|
cal_seqs[idx] != I_INF
|
||||||
|
) # not out of bounds and next node exists
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
values, idx = carry
|
values, idx = carry
|
||||||
i = cal_seqs[idx]
|
i = cal_seqs[idx]
|
||||||
|
|
||||||
def input_node():
|
def input_node():
|
||||||
z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
|
return values
|
||||||
new_values = values.at[i].set(z)
|
|
||||||
return new_values
|
|
||||||
|
|
||||||
def otherwise():
|
def otherwise():
|
||||||
|
# calculate connections
|
||||||
conn_indices = u_conns[:, i]
|
conn_indices = u_conns[:, i]
|
||||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
|
||||||
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||||
state, hit_attrs, values
|
state, hit_attrs, values
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# calculate nodes
|
||||||
z = self.node_gene.forward(
|
z = self.node_gene.forward(
|
||||||
state,
|
state,
|
||||||
nodes_attrs[i],
|
nodes_attrs[i],
|
||||||
ins,
|
ins,
|
||||||
is_output_node=jnp.isin(i, self.output_idx),
|
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# set new value
|
||||||
new_values = values.at[i].set(z)
|
new_values = values.at[i].set(z)
|
||||||
return new_values
|
return new_values
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
from typing import Callable
|
import jax
|
||||||
|
from jax import vmap, numpy as jnp
|
||||||
import jax, jax.numpy as jnp
|
|
||||||
from .utils import unflatten_conns
|
from .utils import unflatten_conns
|
||||||
|
|
||||||
from . import BaseGenome
|
from .base import BaseGenome
|
||||||
|
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||||
|
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||||
from ..gene import DefaultNodeGene, DefaultConnGene
|
from ..gene import DefaultNodeGene, DefaultConnGene
|
||||||
from .operations import DefaultMutation, DefaultCrossover
|
|
||||||
|
|
||||||
|
from tensorneat.common import attach_with_inf
|
||||||
|
|
||||||
class RecurrentGenome(BaseGenome):
|
class RecurrentGenome(BaseGenome):
|
||||||
"""Default genome class, with the same behavior as the NEAT-Python"""
|
"""Default genome class, with the same behavior as the NEAT-Python"""
|
||||||
@@ -17,14 +18,17 @@ class RecurrentGenome(BaseGenome):
|
|||||||
self,
|
self,
|
||||||
num_inputs: int,
|
num_inputs: int,
|
||||||
num_outputs: int,
|
num_outputs: int,
|
||||||
max_nodes = 50,
|
max_nodes=50,
|
||||||
max_conns = 100,
|
max_conns=100,
|
||||||
node_gene=DefaultNodeGene(),
|
node_gene=DefaultNodeGene(),
|
||||||
conn_gene=DefaultConnGene(),
|
conn_gene=DefaultConnGene(),
|
||||||
mutation=DefaultMutation(),
|
mutation=DefaultMutation(),
|
||||||
crossover=DefaultCrossover(),
|
crossover=DefaultCrossover(),
|
||||||
|
distance=DefaultDistance(),
|
||||||
|
output_transform=None,
|
||||||
|
input_transform=None,
|
||||||
|
init_hidden_layers=(),
|
||||||
activate_time=10,
|
activate_time=10,
|
||||||
output_transform: Callable = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inputs,
|
num_inputs,
|
||||||
@@ -35,29 +39,25 @@ class RecurrentGenome(BaseGenome):
|
|||||||
conn_gene,
|
conn_gene,
|
||||||
mutation,
|
mutation,
|
||||||
crossover,
|
crossover,
|
||||||
|
distance,
|
||||||
|
output_transform,
|
||||||
|
input_transform,
|
||||||
|
init_hidden_layers,
|
||||||
)
|
)
|
||||||
self.activate_time = activate_time
|
self.activate_time = activate_time
|
||||||
|
|
||||||
if output_transform is not None:
|
|
||||||
try:
|
|
||||||
_ = output_transform(jnp.zeros(num_outputs))
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Output transform function failed: {e}")
|
|
||||||
self.output_transform = output_transform
|
|
||||||
|
|
||||||
def transform(self, state, nodes, conns):
|
def transform(self, state, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
return nodes, conns, u_conns
|
return nodes, conns, u_conns
|
||||||
|
|
||||||
def restore(self, state, transformed):
|
def forward(self, state, transformed, inputs):
|
||||||
nodes, conns, u_conns = transformed
|
nodes, conns, u_conns = transformed
|
||||||
return nodes, conns
|
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
|
||||||
nodes, conns = transformed
|
|
||||||
|
|
||||||
vals = jnp.full((self.max_nodes,), jnp.nan)
|
vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||||
nodes_attrs = nodes[:, 1:] # remove index
|
|
||||||
|
nodes_attrs = vmap(extract_node_attrs)(nodes)
|
||||||
|
conns_attrs = vmap(extract_conn_attrs)(conns)
|
||||||
|
expand_conns_attrs = attach_with_inf(conns_attrs, u_conns)
|
||||||
|
|
||||||
def body_func(_, values):
|
def body_func(_, values):
|
||||||
|
|
||||||
@@ -65,14 +65,14 @@ class RecurrentGenome(BaseGenome):
|
|||||||
values = values.at[self.input_idx].set(inputs)
|
values = values.at[self.input_idx].set(inputs)
|
||||||
|
|
||||||
# calculate connections
|
# calculate connections
|
||||||
node_ins = jax.vmap(
|
node_ins = vmap(
|
||||||
jax.vmap(self.conn_gene.forward, in_axes=(None, 1, None)),
|
vmap(self.conn_gene.forward, in_axes=(None, 0, None)),
|
||||||
in_axes=(None, 1, 0),
|
in_axes=(None, 0, 0),
|
||||||
)(state, conns, values)
|
)(state, expand_conns_attrs, values)
|
||||||
|
|
||||||
# calculate nodes
|
# calculate nodes
|
||||||
is_output_nodes = jnp.isin(jnp.arange(self.max_nodes), self.output_idx)
|
is_output_nodes = jnp.isin(nodes[:, 0], self.output_idx)
|
||||||
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
|
values = vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
|
||||||
state, nodes_attrs, node_ins.T, is_output_nodes
|
state, nodes_attrs, node_ins.T, is_output_nodes
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,3 +87,6 @@ class RecurrentGenome(BaseGenome):
|
|||||||
|
|
||||||
def sympy_func(self, state, network, precision=3):
|
def sympy_func(self, state, network, precision=3):
|
||||||
raise ValueError("Sympy function is not supported for Recurrent Network!")
|
raise ValueError("Sympy function is not supported for Recurrent Network!")
|
||||||
|
|
||||||
|
def visualize(self, network):
|
||||||
|
raise ValueError("Visualize function is not supported for Recurrent Network!")
|
||||||
|
|||||||
@@ -6,12 +6,11 @@ from jax import numpy as jnp, Array, jit, vmap
|
|||||||
|
|
||||||
I_INF = np.iinfo(jnp.int32).max # infinite int
|
I_INF = np.iinfo(jnp.int32).max # infinite int
|
||||||
|
|
||||||
# TODO: strange implementation
|
|
||||||
def attach_with_inf(arr, idx):
|
def attach_with_inf(arr, idx):
|
||||||
expand_size = arr.ndim - idx.ndim
|
target_dim = arr.ndim + idx.ndim - 1
|
||||||
expand_idx = jnp.expand_dims(
|
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
|
||||||
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
|
|
||||||
)
|
|
||||||
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
|
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user