modify NEAT package; successfully run xor example
This commit is contained in:
0
Pipeline 20240711012327.pkl
Normal file
0
Pipeline 20240711012327.pkl
Normal file
@@ -1,43 +1,38 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.algorithm.neat import NEAT
|
||||
from tensorneat.genome import DefaultGenome, DefaultNodeGene, DefaultMutation
|
||||
from tensorneat.problem.func_fit import XOR3d
|
||||
from tensorneat.common import Act, Agg
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DenseInitialize(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_default=Act.tanh,
|
||||
# activation_options=(Act.tanh,),
|
||||
activation_options=ACT_ALL,
|
||||
aggregation_default=Agg.sum,
|
||||
# aggregation_options=(Agg.sum,),
|
||||
aggregation_options=AGG_ALL,
|
||||
),
|
||||
output_transform=Act.standard_sigmoid, # the activation function for output node
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.1,
|
||||
conn_add=0.1,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=20,
|
||||
compatibility_threshold=2,
|
||||
survival_threshold=0.01,
|
||||
genome=DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
init_hidden_layers=(),
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=Act.tanh,
|
||||
aggregation_default=Agg.sum,
|
||||
aggregation_options=Agg.sum,
|
||||
),
|
||||
output_transform=Act.standard_sigmoid, # the activation function for output node
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.1,
|
||||
conn_add=0.1,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=20,
|
||||
compatibility_threshold=2,
|
||||
survival_threshold=0.01, # magic
|
||||
),
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
fitness_target=-1e-3,
|
||||
generation_limit=500,
|
||||
fitness_target=-1e-8,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
@@ -47,4 +42,3 @@ if __name__ == "__main__":
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
pipeline.save(state=state)
|
||||
|
||||
120
network.svg
120
network.svg
@@ -6,7 +6,7 @@
|
||||
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
|
||||
<cc:Work>
|
||||
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
|
||||
<dc:date>2024-07-10T16:50:19.947855</dc:date>
|
||||
<dc:date>2024-07-10T19:47:34.359228</dc:date>
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:creator>
|
||||
<cc:Agent>
|
||||
@@ -32,222 +32,222 @@ z
|
||||
<g id="patch_2">
|
||||
<path d="M 44.79098 308.403612
|
||||
Q 87.590594 244.204191 129.770035 180.93503
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 125.887134 183.153831
|
||||
L 129.770035 180.93503
|
||||
L 129.215335 185.372632
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_3">
|
||||
<path d="M 46.916335 239.00779
|
||||
Q 87.591722 208.50125 127.372682 178.665529
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 122.972682 179.465529
|
||||
L 127.372682 178.665529
|
||||
L 125.372682 182.665529
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_4">
|
||||
<path d="M 48.647998 172.8
|
||||
Q 87.590519 172.8 125.415005 172.8
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 121.415005 170.8
|
||||
L 125.415005 172.8
|
||||
L 121.415005 174.8
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_5">
|
||||
<path d="M 46.916335 106.59221
|
||||
Q 87.591722 137.09875 127.372682 166.934471
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 125.372682 162.934471
|
||||
L 127.372682 166.934471
|
||||
L 122.972682 166.134471
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_6">
|
||||
<path d="M 44.79098 37.196388
|
||||
Q 87.590594 101.395809 129.770035 164.66497
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 129.215335 160.227368
|
||||
L 129.770035 164.66497
|
||||
L 125.887134 162.446169
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_7">
|
||||
<path d="M 143.30257 175.840943
|
||||
Q 182.796502 190.651168 221.243586 205.068824
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 218.200516 201.791672
|
||||
L 221.243586 205.068824
|
||||
L 216.796023 205.536989
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_8">
|
||||
<path d="M 143.30257 169.759057
|
||||
Q 182.796502 154.948832 221.243586 140.531176
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 216.796023 140.063011
|
||||
L 221.243586 140.531176
|
||||
L 218.200516 143.808328
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_9">
|
||||
<path d="M 238.509181 211.543422
|
||||
Q 278.003113 226.353647 316.450198 240.771303
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 313.407128 237.494151
|
||||
L 316.450198 240.771303
|
||||
L 312.002634 241.239468
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_10">
|
||||
<path d="M 238.509181 205.461536
|
||||
Q 278.003113 190.651312 316.450198 176.233655
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 312.002634 175.765491
|
||||
L 316.450198 176.233655
|
||||
L 313.407128 179.510807
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_11">
|
||||
<path d="M 236.155746 202.027265
|
||||
Q 278.00531 154.946506 319.112092 108.701376
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 314.959818 110.362286
|
||||
L 319.112092 108.701376
|
||||
L 317.949455 113.019741
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_12">
|
||||
<path d="M 236.155746 143.572735
|
||||
Q 278.00531 190.653494 319.112092 236.898624
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 317.949455 232.580259
|
||||
L 319.112092 236.898624
|
||||
L 314.959818 235.237714
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_13">
|
||||
<path d="M 238.509181 140.138464
|
||||
Q 278.003113 154.948688 316.450198 169.366345
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 313.407128 166.089193
|
||||
L 316.450198 169.366345
|
||||
L 312.002634 169.834509
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_14">
|
||||
<path d="M 238.509181 134.056578
|
||||
Q 278.003113 119.246353 316.450198 104.828697
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 312.002634 104.360532
|
||||
L 316.450198 104.828697
|
||||
L 313.407128 108.105849
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_15">
|
||||
<path d="M 334.267833 244.204959
|
||||
Q 373.210353 244.204959 411.03484 244.204959
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 407.03484 242.204959
|
||||
L 411.03484 244.204959
|
||||
L 407.03484 246.204959
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_16">
|
||||
<path d="M 332.53617 239.00779
|
||||
Q 373.211557 208.50125 412.992517 178.665529
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 408.592517 179.465529
|
||||
L 412.992517 178.665529
|
||||
L 410.992517 182.665529
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_17">
|
||||
<path d="M 330.410815 236.998654
|
||||
Q 373.210429 172.799232 415.38987 109.530071
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 411.506968 111.748872
|
||||
L 415.38987 109.530071
|
||||
L 414.83517 113.967673
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_18">
|
||||
<path d="M 332.53617 177.997169
|
||||
Q 373.211557 208.503709 412.992517 238.339429
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 410.992517 234.339429
|
||||
L 412.992517 238.339429
|
||||
L 408.592517 237.539429
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_19">
|
||||
<path d="M 334.267833 172.8
|
||||
Q 373.210353 172.8 411.03484 172.8
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 407.03484 170.8
|
||||
L 411.03484 172.8
|
||||
L 407.03484 174.8
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_20">
|
||||
<path d="M 332.53617 167.602831
|
||||
Q 373.211557 137.096291 412.992517 107.260571
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 408.592517 108.060571
|
||||
L 412.992517 107.260571
|
||||
L 410.992517 111.260571
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_21">
|
||||
<path d="M 330.410815 108.601346
|
||||
Q 373.210429 172.800768 415.38987 236.069929
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 414.83517 231.632327
|
||||
L 415.38987 236.069929
|
||||
L 411.506968 233.851128
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_22">
|
||||
<path d="M 332.53617 106.59221
|
||||
Q 373.211557 137.09875 412.992517 166.934471
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 410.992517 162.934471
|
||||
L 412.992517 166.934471
|
||||
L 408.592517 166.134471
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_23">
|
||||
<path d="M 334.267833 101.395041
|
||||
Q 373.210353 101.395041 411.03484 101.395041
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 407.03484 99.395041
|
||||
L 411.03484 101.395041
|
||||
L 407.03484 103.395041
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="PathCollection_1">
|
||||
<path d="M 39.986777 324.270171
|
||||
@@ -260,7 +260,7 @@ C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
|
||||
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
|
||||
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 252.865213
|
||||
C 42.283503 252.865213 44.486471 251.952714 46.110501 250.328683
|
||||
C 47.734532 248.704652 48.647031 246.501685 48.647031 244.204959
|
||||
@@ -271,7 +271,7 @@ C 32.239022 239.705265 31.326523 241.908232 31.326523 244.204959
|
||||
C 31.326523 246.501685 32.239022 248.704652 33.863053 250.328683
|
||||
C 35.487083 251.952714 37.690051 252.865213 39.986777 252.865213
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 181.460254
|
||||
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
|
||||
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
|
||||
@@ -282,7 +282,7 @@ C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
|
||||
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
|
||||
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 110.055295
|
||||
C 42.283503 110.055295 44.486471 109.142796 46.110501 107.518766
|
||||
C 47.734532 105.894735 48.647031 103.691768 48.647031 101.395041
|
||||
@@ -293,7 +293,7 @@ C 32.239022 96.895348 31.326523 99.098315 31.326523 101.395041
|
||||
C 31.326523 103.691768 32.239022 105.894735 33.863053 107.518766
|
||||
C 35.487083 109.142796 37.690051 110.055295 39.986777 110.055295
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 38.650337
|
||||
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
|
||||
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
|
||||
@@ -304,7 +304,7 @@ C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
|
||||
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
|
||||
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 135.193388 181.460254
|
||||
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
|
||||
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
|
||||
@@ -315,7 +315,7 @@ C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
|
||||
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
|
||||
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 230.4 217.162733
|
||||
C 232.696726 217.162733 234.899694 216.250234 236.523724 214.626204
|
||||
C 238.147755 213.002173 239.060254 210.799206 239.060254 208.502479
|
||||
@@ -326,7 +326,7 @@ C 222.652245 204.002786 221.739746 206.205753 221.739746 208.502479
|
||||
C 221.739746 210.799206 222.652245 213.002173 224.276276 214.626204
|
||||
C 225.900306 216.250234 228.103274 217.162733 230.4 217.162733
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 230.4 145.757775
|
||||
C 232.696726 145.757775 234.899694 144.845276 236.523724 143.221245
|
||||
C 238.147755 141.597214 239.060254 139.394247 239.060254 137.097521
|
||||
@@ -337,7 +337,7 @@ C 222.652245 132.597827 221.739746 134.800794 221.739746 137.097521
|
||||
C 221.739746 139.394247 222.652245 141.597214 224.276276 143.221245
|
||||
C 225.900306 144.845276 228.103274 145.757775 230.4 145.757775
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 325.606612 252.865213
|
||||
C 327.903338 252.865213 330.106305 251.952714 331.730336 250.328683
|
||||
C 333.354367 248.704652 334.266866 246.501685 334.266866 244.204959
|
||||
@@ -348,7 +348,7 @@ C 317.858857 239.705265 316.946358 241.908232 316.946358 244.204959
|
||||
C 316.946358 246.501685 317.858857 248.704652 319.482887 250.328683
|
||||
C 321.106918 251.952714 323.309885 252.865213 325.606612 252.865213
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 325.606612 181.460254
|
||||
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
|
||||
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
|
||||
@@ -359,7 +359,7 @@ C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
|
||||
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
|
||||
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 325.606612 110.055295
|
||||
C 327.903338 110.055295 330.106305 109.142796 331.730336 107.518766
|
||||
C 333.354367 105.894735 334.266866 103.691768 334.266866 101.395041
|
||||
@@ -370,7 +370,7 @@ C 317.858857 96.895348 316.946358 99.098315 316.946358 101.395041
|
||||
C 316.946358 103.691768 317.858857 105.894735 319.482887 107.518766
|
||||
C 321.106918 109.142796 323.309885 110.055295 325.606612 110.055295
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 420.813223 252.865213
|
||||
C 423.109949 252.865213 425.312917 251.952714 426.936947 250.328683
|
||||
C 428.560978 248.704652 429.473477 246.501685 429.473477 244.204959
|
||||
@@ -381,7 +381,7 @@ C 413.065468 239.705265 412.152969 241.908232 412.152969 244.204959
|
||||
C 412.152969 246.501685 413.065468 248.704652 414.689499 250.328683
|
||||
C 416.313529 251.952714 418.516497 252.865213 420.813223 252.865213
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 420.813223 181.460254
|
||||
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
|
||||
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
|
||||
@@ -392,7 +392,7 @@ C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
|
||||
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
|
||||
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 420.813223 110.055295
|
||||
C 423.109949 110.055295 425.312917 109.142796 426.936947 107.518766
|
||||
C 428.560978 105.894735 429.473477 103.691768 429.473477 101.395041
|
||||
@@ -403,12 +403,12 @@ C 413.065468 96.895348 412.152969 99.098315 412.152969 101.395041
|
||||
C 412.152969 103.691768 413.065468 105.894735 414.689499 107.518766
|
||||
C 416.313529 109.142796 418.516497 110.055295 420.813223 110.055295
|
||||
z
|
||||
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="p8fe09283f8">
|
||||
<clipPath id="p572566e0dc">
|
||||
<rect x="0" y="0" width="460.8" height="345.6"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
|
||||
|
Before Width: | Height: | Size: 18 KiB After Width: | Height: | Size: 18 KiB |
@@ -14,13 +14,11 @@ class BaseAlgorithm(StatefulBaseClass):
|
||||
"""transform the genome into a neural network"""
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, state, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
def show_details(self, state: State, fitness):
|
||||
"""Visualize the running details of the algorithm"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@@ -30,15 +28,3 @@ class BaseAlgorithm(StatefulBaseClass):
|
||||
@property
|
||||
def num_outputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def pop_size(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def member_count(self, state: State):
|
||||
# to analysis the species
|
||||
raise NotImplementedError
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,40 +1,93 @@
|
||||
from tensorneat.common import State
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .. import BaseAlgorithm
|
||||
from .species import *
|
||||
from tensorneat.common import State
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
class NEAT(BaseAlgorithm):
|
||||
def __init__(
|
||||
self,
|
||||
species: BaseSpecies,
|
||||
genome: BaseGenome,
|
||||
pop_size: int,
|
||||
species_size: int = 10,
|
||||
max_stagnation: int = 15,
|
||||
species_elitism: int = 2,
|
||||
spawn_number_change_rate: float = 0.5,
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.0,
|
||||
species_fitness_func: callable = jnp.max,
|
||||
):
|
||||
self.species = species
|
||||
self.genome = species.genome
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_controller = SpeciesController(
|
||||
pop_size,
|
||||
species_size,
|
||||
max_stagnation,
|
||||
species_elitism,
|
||||
spawn_number_change_rate,
|
||||
genome_elitism,
|
||||
survival_threshold,
|
||||
min_species_size,
|
||||
compatibility_threshold,
|
||||
species_fitness_func,
|
||||
)
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.species.setup(state)
|
||||
# setup state
|
||||
state = self.genome.setup(state)
|
||||
|
||||
k1, randkey = jax.random.split(state.randkey, 2)
|
||||
|
||||
# initialize the population
|
||||
initialize_keys = jax.random.split(k1, self.pop_size)
|
||||
pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
|
||||
state, initialize_keys
|
||||
)
|
||||
|
||||
state = state.register(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
generation=jnp.float32(0),
|
||||
)
|
||||
|
||||
# initialize species state
|
||||
state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])
|
||||
|
||||
return state.update(randkey=randkey)
|
||||
|
||||
def ask(self, state):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def tell(self, state, fitness):
|
||||
state = state.update(generation=state.generation + 1)
|
||||
|
||||
# tell fitness to species controller
|
||||
state, winner, loser, elite_mask = self.species_controller.update_species(
|
||||
state,
|
||||
fitness,
|
||||
)
|
||||
|
||||
# create next population
|
||||
state = self._create_next_generation(state, winner, loser, elite_mask)
|
||||
|
||||
# speciate the next population
|
||||
state = self.species_controller.speciate(state, self.genome.execute_distance)
|
||||
|
||||
return state
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.species.ask(state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
return self.species.tell(state, fitness)
|
||||
|
||||
def transform(self, state, individual):
|
||||
"""transform the genome into a neural network"""
|
||||
nodes, conns = individual
|
||||
return self.genome.transform(state, nodes, conns)
|
||||
|
||||
def restore(self, state, transformed):
|
||||
return self.genome.restore(state, transformed)
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
return self.genome.forward(state, transformed, inputs)
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
return self.genome.update_by_batch(state, batch_input, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.genome.num_inputs
|
||||
@@ -43,13 +96,70 @@ class NEAT(BaseAlgorithm):
|
||||
def num_outputs(self):
|
||||
return self.genome.num_outputs
|
||||
|
||||
@property
|
||||
def pop_size(self):
|
||||
return self.species.pop_size
|
||||
def _create_next_generation(self, state, winner, loser, elite_mask):
|
||||
|
||||
def member_count(self, state: State):
|
||||
return state.member_count
|
||||
# find next node key for mutation
|
||||
all_nodes_keys = state.pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(
|
||||
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
|
||||
)
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
return state.generation
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
return state.update(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
)
|
||||
|
||||
def show_details(self, state, fitness):
|
||||
member_count = jax.device_get(state.species.member_count)
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
|
||||
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
|
||||
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
|
||||
|
||||
max_node_cnt, min_node_cnt, mean_node_cnt = (
|
||||
max(nodes_cnt),
|
||||
min(nodes_cnt),
|
||||
np.mean(nodes_cnt),
|
||||
)
|
||||
|
||||
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
|
||||
max(conns_cnt),
|
||||
min(conns_cnt),
|
||||
np.mean(conns_cnt),
|
||||
)
|
||||
|
||||
print(
|
||||
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
|
||||
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
|
||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||
)
|
||||
|
||||
@@ -1,54 +1,35 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseSpecies
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
StatefulBaseClass,
|
||||
rank_elements,
|
||||
argmin_with_mask,
|
||||
fetch_first,
|
||||
)
|
||||
from tensorneat.genome.utils import (
|
||||
extract_conn_attrs,
|
||||
extract_node_attrs,
|
||||
)
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
"""
|
||||
Core procedures of NEAT algorithm, contains the following steps:
|
||||
1. Update the fitness of each species;
|
||||
2. Decide which species will be stagnation;
|
||||
3. Decide the number of members of each species in the next generation;
|
||||
4. Choice the crossover pair for each species;
|
||||
5. Divided the whole new population into different species;
|
||||
|
||||
This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python.
|
||||
The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases.
|
||||
"""
|
||||
|
||||
|
||||
class DefaultSpecies(BaseSpecies):
|
||||
class SpeciesController(StatefulBaseClass):
|
||||
def __init__(
|
||||
self,
|
||||
genome: BaseGenome,
|
||||
pop_size,
|
||||
species_size,
|
||||
compatibility_disjoint: float = 1.0,
|
||||
compatibility_weight: float = 0.4,
|
||||
max_stagnation: int = 15,
|
||||
species_elitism: int = 2,
|
||||
spawn_number_change_rate: float = 0.5,
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.0,
|
||||
max_stagnation,
|
||||
species_elitism,
|
||||
spawn_number_change_rate,
|
||||
genome_elitism,
|
||||
survival_threshold,
|
||||
min_species_size,
|
||||
compatibility_threshold,
|
||||
species_fitness_func,
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_size = species_size
|
||||
|
||||
self.compatibility_disjoint = compatibility_disjoint
|
||||
self.compatibility_weight = compatibility_weight
|
||||
self.species_arange = np.arange(self.species_size)
|
||||
self.max_stagnation = max_stagnation
|
||||
self.species_elitism = species_elitism
|
||||
self.spawn_number_change_rate = spawn_number_change_rate
|
||||
@@ -56,42 +37,33 @@ class DefaultSpecies(BaseSpecies):
|
||||
self.survival_threshold = survival_threshold
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.species_fitness_func = species_fitness_func
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
def setup(self, state, first_nodes, first_conns):
|
||||
# the unique index (primary key) for each species
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.genome.setup(state)
|
||||
k1, randkey = jax.random.split(state.randkey, 2)
|
||||
# the best fitness of each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# initialize the population
|
||||
initialize_keys = jax.random.split(randkey, self.pop_size)
|
||||
pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))(
|
||||
state, initialize_keys
|
||||
)
|
||||
# the last 1 that the species improved
|
||||
last_improved = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
species_keys = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the unique index (primary key) for each species
|
||||
best_fitness = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the best fitness of each species
|
||||
last_improved = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the last 1 that the species improved
|
||||
member_count = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the number of members of each species
|
||||
idx2species = jnp.zeros(self.pop_size) # the species index of each individual
|
||||
# the number of members of each species
|
||||
member_count = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the species index of each individual
|
||||
idx2species = jnp.zeros(self.pop_size)
|
||||
|
||||
# nodes for each center genome of each species
|
||||
center_nodes = jnp.full(
|
||||
(self.species_size, self.genome.max_nodes, self.genome.node_gene.length),
|
||||
(self.species_size, *first_nodes.shape),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
# connections for each center genome of each species
|
||||
center_conns = jnp.full(
|
||||
(self.species_size, self.genome.max_conns, self.genome.conn_gene.length),
|
||||
(self.species_size, *first_conns.shape),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
@@ -99,16 +71,10 @@ class DefaultSpecies(BaseSpecies):
|
||||
best_fitness = best_fitness.at[0].set(-jnp.inf)
|
||||
last_improved = last_improved.at[0].set(0)
|
||||
member_count = member_count.at[0].set(self.pop_size)
|
||||
center_nodes = center_nodes.at[0].set(pop_nodes[0])
|
||||
center_conns = center_conns.at[0].set(pop_conns[0])
|
||||
center_nodes = center_nodes.at[0].set(first_nodes)
|
||||
center_conns = center_conns.at[0].set(first_conns)
|
||||
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
state = state.update(randkey=randkey)
|
||||
|
||||
return state.register(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_state = State(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
@@ -117,53 +83,50 @@ class DefaultSpecies(BaseSpecies):
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=jnp.float32(1), # 0 is reserved for the first species
|
||||
generation=jnp.float32(0),
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def tell(self, state, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(generation=state.generation + 1, randkey=randkey)
|
||||
state, winner, loser, elite_mask = self.update_species(state, fitness)
|
||||
state = self.create_next_generation(state, winner, loser, elite_mask)
|
||||
state = self.speciate(state)
|
||||
|
||||
return state
|
||||
return state.register(species=species_state)
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
species_state = state.species
|
||||
|
||||
# update the fitness of each species
|
||||
state, species_fitness = self.update_species_fitness(state, fitness)
|
||||
species_fitness = self._update_species_fitness(species_state, fitness)
|
||||
|
||||
# stagnation species
|
||||
state, species_fitness = self.stagnation(state, species_fitness)
|
||||
species_state, species_fitness = self._stagnation(
|
||||
species_state, species_fitness, state.generation
|
||||
)
|
||||
|
||||
# sort species_info by their fitness. (also push nan to the end)
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1] # fitness from high to low
|
||||
|
||||
state = state.update(
|
||||
species_keys=state.species_keys[sort_indices],
|
||||
best_fitness=state.best_fitness[sort_indices],
|
||||
last_improved=state.last_improved[sort_indices],
|
||||
member_count=state.member_count[sort_indices],
|
||||
center_nodes=state.center_nodes[sort_indices],
|
||||
center_conns=state.center_conns[sort_indices],
|
||||
species_state = species_state.update(
|
||||
species_keys=species_state.species_keys[sort_indices],
|
||||
best_fitness=species_state.best_fitness[sort_indices],
|
||||
last_improved=species_state.last_improved[sort_indices],
|
||||
member_count=species_state.member_count[sort_indices],
|
||||
center_nodes=species_state.center_nodes[sort_indices],
|
||||
center_conns=species_state.center_conns[sort_indices],
|
||||
)
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
state, spawn_number = self.cal_spawn_numbers(state)
|
||||
spawn_number = self._cal_spawn_numbers(species_state)
|
||||
|
||||
k1, k2 = jax.random.split(state.randkey)
|
||||
# crossover info
|
||||
state, winner, loser, elite_mask = self.create_crossover_pair(
|
||||
state, spawn_number, fitness
|
||||
winner, loser, elite_mask = self._create_crossover_pair(
|
||||
species_state, k1, spawn_number, fitness
|
||||
)
|
||||
|
||||
return state.update(randkey=k2), winner, loser, elite_mask
|
||||
return (
|
||||
state.update(randkey=k2, species=species_state),
|
||||
winner,
|
||||
loser,
|
||||
elite_mask,
|
||||
)
|
||||
|
||||
def update_species_fitness(self, state, fitness):
|
||||
def _update_species_fitness(self, species_state, fitness):
|
||||
"""
|
||||
obtain the fitness of the species by the fitness of each individual.
|
||||
use max criterion.
|
||||
@@ -171,14 +134,16 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
def aux_func(idx):
|
||||
s_fitness = jnp.where(
|
||||
state.idx2species == state.species_keys[idx], fitness, -jnp.inf
|
||||
species_state.idx2species == species_state.species_keys[idx],
|
||||
fitness,
|
||||
-jnp.inf,
|
||||
)
|
||||
val = jnp.max(s_fitness)
|
||||
val = self.species_fitness_func(s_fitness)
|
||||
return val
|
||||
|
||||
return state, jax.vmap(aux_func)(self.species_arange)
|
||||
return vmap(aux_func)(self.species_arange)
|
||||
|
||||
def stagnation(self, state, species_fitness):
|
||||
def _stagnation(self, species_state, species_fitness, generation):
|
||||
"""
|
||||
stagnation species.
|
||||
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
|
||||
@@ -187,28 +152,36 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
def check_stagnation(idx):
|
||||
# determine whether the species stagnation
|
||||
st = (
|
||||
species_fitness[idx] <= state.best_fitness[idx]
|
||||
) & ( # not better than the best fitness of the species
|
||||
state.generation - state.last_improved[idx] > self.max_stagnation
|
||||
) # for a long time
|
||||
|
||||
# not better than the best fitness of the species
|
||||
# for a long time
|
||||
st = (species_fitness[idx] <= species_state.best_fitness[idx]) & (
|
||||
generation - species_state.last_improved[idx] > self.max_stagnation
|
||||
)
|
||||
|
||||
# update last_improved and best_fitness
|
||||
# whether better than the best fitness of the species
|
||||
li, bf = jax.lax.cond(
|
||||
species_fitness[idx] > state.best_fitness[idx],
|
||||
lambda: (state.generation, species_fitness[idx]), # update
|
||||
species_fitness[idx] > species_state.best_fitness[idx],
|
||||
lambda: (generation, species_fitness[idx]), # update
|
||||
lambda: (
|
||||
state.last_improved[idx],
|
||||
state.best_fitness[idx],
|
||||
species_state.last_improved[idx],
|
||||
species_state.best_fitness[idx],
|
||||
), # not update
|
||||
)
|
||||
|
||||
return st, bf, li
|
||||
|
||||
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(
|
||||
spe_st, best_fitness, last_improved = vmap(check_stagnation)(
|
||||
self.species_arange
|
||||
)
|
||||
|
||||
# update species state
|
||||
species_state = species_state.update(
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
)
|
||||
|
||||
# elite species will not be stagnation
|
||||
species_rank = rank_elements(species_fitness)
|
||||
spe_st = jnp.where(
|
||||
@@ -224,18 +197,18 @@ class DefaultSpecies(BaseSpecies):
|
||||
jnp.nan, # best_fitness
|
||||
jnp.nan, # last_improved
|
||||
jnp.nan, # member_count
|
||||
jnp.full_like(species_state.center_nodes[idx], jnp.nan),
|
||||
jnp.full_like(species_state.center_conns[idx], jnp.nan),
|
||||
-jnp.inf, # species_fitness
|
||||
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
|
||||
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
|
||||
), # stagnation species
|
||||
lambda: (
|
||||
state.species_keys[idx],
|
||||
best_fitness[idx],
|
||||
last_improved[idx],
|
||||
state.member_count[idx],
|
||||
species_state.species_keys[idx],
|
||||
species_state.best_fitness[idx],
|
||||
species_state.last_improved[idx],
|
||||
species_state.member_count[idx],
|
||||
species_state.center_nodes[idx],
|
||||
species_state.center_conns[idx],
|
||||
species_fitness[idx],
|
||||
state.center_nodes[idx],
|
||||
state.center_conns[idx],
|
||||
), # not stagnation species
|
||||
)
|
||||
|
||||
@@ -244,13 +217,13 @@ class DefaultSpecies(BaseSpecies):
|
||||
best_fitness,
|
||||
last_improved,
|
||||
member_count,
|
||||
species_fitness,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
) = jax.vmap(update_func)(self.species_arange)
|
||||
species_fitness,
|
||||
) = vmap(update_func)(self.species_arange)
|
||||
|
||||
return (
|
||||
state.update(
|
||||
species_state.update(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
@@ -261,7 +234,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
species_fitness,
|
||||
)
|
||||
|
||||
def cal_spawn_numbers(self, state):
|
||||
def _cal_spawn_numbers(self, species_state):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
the species with higher fitness will have more members
|
||||
@@ -269,7 +242,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
|
||||
"""
|
||||
|
||||
species_keys = state.species_keys
|
||||
species_keys = species_state.species_keys
|
||||
|
||||
is_species_valid = ~jnp.isnan(species_keys)
|
||||
valid_species_num = jnp.sum(is_species_valid)
|
||||
@@ -288,7 +261,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
) # calculate member
|
||||
|
||||
# Avoid too much variation of numbers for a species
|
||||
previous_size = state.member_count
|
||||
previous_size = species_state.member_count
|
||||
spawn_number = (
|
||||
previous_size
|
||||
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
|
||||
@@ -301,14 +274,17 @@ class DefaultSpecies(BaseSpecies):
|
||||
# add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(error)
|
||||
|
||||
return state, spawn_number
|
||||
return spawn_number
|
||||
|
||||
def create_crossover_pair(self, state, spawn_number, fitness):
|
||||
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
|
||||
s_idx = self.species_arange
|
||||
p_idx = jnp.arange(self.pop_size)
|
||||
|
||||
def aux_func(key, idx):
|
||||
members = state.idx2species == state.species_keys[idx]
|
||||
# choose parents from the in the same species
|
||||
# key -> randkey, idx -> the idx of current species
|
||||
|
||||
members = species_state.idx2species == species_state.species_keys[idx]
|
||||
members_num = jnp.sum(members)
|
||||
|
||||
members_fitness = jnp.where(members, fitness, -jnp.inf)
|
||||
@@ -333,11 +309,16 @@ class DefaultSpecies(BaseSpecies):
|
||||
elite = jnp.where(p_idx < self.genome_elitism, True, False)
|
||||
return fa, ma, elite
|
||||
|
||||
randkey_, randkey = jax.random.split(state.randkey)
|
||||
fas, mas, elites = jax.vmap(aux_func)(
|
||||
jax.random.split(randkey_, self.species_size), s_idx
|
||||
# choose parents to crossover in each species
|
||||
# fas, mas, elites: (self.species_size, self.pop_size)
|
||||
# fas -> father indices, mas -> mother indices, elites -> whether elite or not
|
||||
fas, mas, elites = vmap(aux_func)(
|
||||
jax.random.split(randkey, self.species_size), s_idx
|
||||
)
|
||||
|
||||
# merge choosen parents from each species into one array
|
||||
# winner, loser, elite_mask: (self.pop_size)
|
||||
# winner -> winner indices, loser -> loser indices, elite_mask -> whether elite or not
|
||||
spawn_number_cum = jnp.cumsum(spawn_number)
|
||||
|
||||
def aux_func(idx):
|
||||
@@ -351,18 +332,18 @@ class DefaultSpecies(BaseSpecies):
|
||||
elites[loc, idx_in_species],
|
||||
)
|
||||
|
||||
part1, part2, elite_mask = jax.vmap(aux_func)(p_idx)
|
||||
part1, part2, elite_mask = vmap(aux_func)(p_idx)
|
||||
|
||||
is_part1_win = fitness[part1] >= fitness[part2]
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return state.update(randkey=randkey), winner, loser, elite_mask
|
||||
return winner, loser, elite_mask
|
||||
|
||||
def speciate(self, state):
|
||||
def speciate(self, state, genome_distance_func: Callable):
|
||||
# prepare distance functions
|
||||
o2p_distance_func = jax.vmap(
|
||||
self.distance, in_axes=(None, None, None, 0, 0)
|
||||
o2p_distance_func = vmap(
|
||||
genome_distance_func, in_axes=(None, None, None, 0, 0)
|
||||
) # one to population
|
||||
|
||||
# idx to specie key
|
||||
@@ -379,7 +360,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
return (i < self.species_size) & (
|
||||
~jnp.isnan(state.species_keys[i])
|
||||
~jnp.isnan(state.species.species_keys[i])
|
||||
) # current species is existing
|
||||
|
||||
def body_func(carry):
|
||||
@@ -392,7 +373,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
|
||||
i2s = i2s.at[closest_idx].set(state.species_keys[i])
|
||||
i2s = i2s.at[closest_idx].set(state.species.species_keys[i])
|
||||
cns = cns.at[i].set(state.pop_nodes[closest_idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
|
||||
|
||||
@@ -404,13 +385,21 @@ class DefaultSpecies(BaseSpecies):
|
||||
_, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances),
|
||||
(
|
||||
0,
|
||||
idx2species,
|
||||
state.species.center_nodes,
|
||||
state.species.center_conns,
|
||||
o2c_distances,
|
||||
),
|
||||
)
|
||||
|
||||
state = state.update(
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
species=state.species.update(
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
),
|
||||
)
|
||||
|
||||
# part 2: assign members to each species
|
||||
@@ -500,12 +489,12 @@ class DefaultSpecies(BaseSpecies):
|
||||
body_func,
|
||||
(
|
||||
0,
|
||||
state.idx2species,
|
||||
state.species.idx2species,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
state.species_keys,
|
||||
state.species.species_keys,
|
||||
o2c_distances,
|
||||
state.next_species_key,
|
||||
state.species.next_species_key,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -514,10 +503,10 @@ class DefaultSpecies(BaseSpecies):
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
|
||||
last_improved = jnp.where(
|
||||
new_created_mask, state.generation, state.last_improved
|
||||
new_created_mask, state.generation, state.species.last_improved
|
||||
)
|
||||
|
||||
# update members count
|
||||
@@ -530,9 +519,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
), # count members
|
||||
)
|
||||
|
||||
member_count = jax.vmap(count_members)(self.species_arange)
|
||||
member_count = vmap(count_members)(self.species_arange)
|
||||
|
||||
return state.update(
|
||||
species_state = state.species.update(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
@@ -543,135 +532,6 @@ class DefaultSpecies(BaseSpecies):
|
||||
next_species_key=next_species_key,
|
||||
)
|
||||
|
||||
def distance(self, state, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
|
||||
state, conns1, conns2
|
||||
)
|
||||
return d
|
||||
|
||||
def node_distance(self, state, nodes1, nodes2):
|
||||
"""
|
||||
The distance of the nodes part for two genomes
|
||||
"""
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate(
|
||||
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
fr_attrs = jax.vmap(extract_node_attrs)(fr)
|
||||
sr_attrs = jax.vmap(extract_node_attrs)(sr)
|
||||
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous node distance
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
|
||||
def conn_distance(self, state, conns1, conns2):
|
||||
"""
|
||||
The distance of the conns part for two genomes
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
cons = jnp.concatenate((conns1, conns2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate(
|
||||
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
fr_attrs = jax.vmap(extract_conn_attrs)(fr)
|
||||
sr_attrs = jax.vmap(extract_conn_attrs)(sr)
|
||||
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous connection distance
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
|
||||
def create_next_generation(self, state, winner, loser, elite_mask):
|
||||
|
||||
# find next node key
|
||||
all_nodes_keys = state.pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(
|
||||
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
|
||||
)
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = jax.vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = jax.vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
return state.update(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species=species_state,
|
||||
)
|
||||
@@ -1,2 +0,0 @@
|
||||
from .base import BaseSpecies
|
||||
from .default import DefaultSpecies
|
||||
@@ -1,20 +0,0 @@
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
class BaseSpecies(StatefulBaseClass):
|
||||
genome: BaseGenome
|
||||
pop_size: int
|
||||
species_size: int
|
||||
|
||||
def ask(self, state: State):
|
||||
raise NotImplementedError
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
raise NotImplementedError
|
||||
|
||||
def speciate(self, state):
|
||||
raise NotImplementedError
|
||||
@@ -1,3 +1,5 @@
|
||||
from .gene import *
|
||||
from .operations import *
|
||||
from .base import BaseGenome
|
||||
from .default import DefaultGenome
|
||||
from .recurrent import RecurrentGenome
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Union, Sequence, Callable
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
@@ -34,14 +34,20 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
response_mutate_power: float = 0.5,
|
||||
response_mutate_rate: float = 0.7,
|
||||
response_replace_rate: float = 0.1,
|
||||
aggregation_default: callable = Agg.sum,
|
||||
aggregation_options: Tuple = (Agg.sum,),
|
||||
aggregation_default: Callable = Agg.sum,
|
||||
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
|
||||
aggregation_replace_rate: float = 0.1,
|
||||
activation_default: callable = Act.sigmoid,
|
||||
activation_options: Tuple = (Act.sigmoid,),
|
||||
activation_default: Callable = Act.sigmoid,
|
||||
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
|
||||
activation_replace_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(aggregation_options, Callable):
|
||||
aggregation_options = [aggregation_options]
|
||||
if isinstance(activation_options, Callable):
|
||||
activation_options = [activation_options]
|
||||
|
||||
self.bias_init_mean = bias_init_mean
|
||||
self.bias_init_std = bias_init_std
|
||||
self.bias_mutate_power = bias_mutate_power
|
||||
|
||||
@@ -13,7 +13,7 @@ class DefaultDistance(BaseDistance):
|
||||
self.compatibility_disjoint = compatibility_disjoint
|
||||
self.compatibility_weight = compatibility_weight
|
||||
|
||||
def __call__(self, state, nodes1, nodes2, conns1, conns2):
|
||||
def __call__(self, state, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
|
||||
@@ -8,5 +8,5 @@ class BaseMutation(StatefulBaseClass):
|
||||
self.genome = genome
|
||||
return state
|
||||
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
def __call__(self, state, randkey, nodes, conns, new_node_key):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -33,17 +33,17 @@ class DefaultMutation(BaseMutation):
|
||||
self.node_add = node_add
|
||||
self.node_delete = node_delete
|
||||
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
def __call__(self, state, randkey, nodes, conns, new_node_key):
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
|
||||
nodes, conns = self.mutate_structure(
|
||||
state, k1, genome, nodes, conns, new_node_key
|
||||
state, k1, nodes, conns, new_node_key
|
||||
)
|
||||
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
|
||||
nodes, conns = self.mutate_values(state, k2, nodes, conns)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
def mutate_structure(self, state, randkey, nodes, conns, new_node_key):
|
||||
def mutate_add_node(key_, nodes_, conns_):
|
||||
"""
|
||||
add a node while do not influence the output of the network
|
||||
@@ -62,7 +62,7 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
# add a new node with identity attrs
|
||||
new_nodes = add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_identity_attrs(state)
|
||||
nodes_, new_node_key, self.genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
|
||||
# add two new connections
|
||||
@@ -71,7 +71,7 @@ class DefaultMutation(BaseMutation):
|
||||
new_conns,
|
||||
i_key,
|
||||
new_node_key,
|
||||
genome.conn_gene.new_identity_attrs(state),
|
||||
self.genome.conn_gene.new_identity_attrs(state),
|
||||
)
|
||||
# second is with the origin attrs
|
||||
new_conns = add_conn(
|
||||
@@ -97,8 +97,8 @@ class DefaultMutation(BaseMutation):
|
||||
key, idx = self.choose_node_key(
|
||||
key_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
self.genome.input_idx,
|
||||
self.genome.output_idx,
|
||||
allow_input_keys=False,
|
||||
allow_output_keys=False,
|
||||
)
|
||||
@@ -136,8 +136,8 @@ class DefaultMutation(BaseMutation):
|
||||
i_key, from_idx = self.choose_node_key(
|
||||
k1_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
self.genome.input_idx,
|
||||
self.genome.output_idx,
|
||||
allow_input_keys=True,
|
||||
allow_output_keys=True,
|
||||
)
|
||||
@@ -146,8 +146,8 @@ class DefaultMutation(BaseMutation):
|
||||
o_key, to_idx = self.choose_node_key(
|
||||
k2_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
self.genome.input_idx,
|
||||
self.genome.output_idx,
|
||||
allow_input_keys=False,
|
||||
allow_output_keys=True,
|
||||
)
|
||||
@@ -161,10 +161,10 @@ class DefaultMutation(BaseMutation):
|
||||
def successful():
|
||||
# add a connection with zero attrs
|
||||
return nodes_, add_conn(
|
||||
conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state)
|
||||
conns_, i_key, o_key, self.genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
|
||||
if genome.network_type == "feedforward":
|
||||
if self.genome.network_type == "feedforward":
|
||||
u_conns = unflatten_conns(nodes_, conns_)
|
||||
conns_exist = u_conns != I_INF
|
||||
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||
@@ -175,7 +175,7 @@ class DefaultMutation(BaseMutation):
|
||||
successful,
|
||||
)
|
||||
|
||||
elif genome.network_type == "recurrent":
|
||||
elif self.genome.network_type == "recurrent":
|
||||
return jax.lax.cond(
|
||||
is_already_exist | (remain_conn_space < 1),
|
||||
nothing,
|
||||
@@ -183,7 +183,7 @@ class DefaultMutation(BaseMutation):
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid network type: {genome.network_type}")
|
||||
raise ValueError(f"Invalid network type: {self.genome.network_type}")
|
||||
|
||||
def mutate_delete_conn(key_, nodes_, conns_):
|
||||
# randomly choose a connection
|
||||
@@ -223,19 +223,19 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def mutate_values(self, state, randkey, genome, nodes, conns):
|
||||
def mutate_values(self, state, randkey, nodes, conns):
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
|
||||
conns_randkeys = jax.random.split(k2, num=genome.max_conns)
|
||||
nodes_randkeys = jax.random.split(k1, num=self.genome.max_nodes)
|
||||
conns_randkeys = jax.random.split(k2, num=self.genome.max_conns)
|
||||
|
||||
node_attrs = vmap(extract_node_attrs)(nodes)
|
||||
new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
new_node_attrs = vmap(self.genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, nodes_randkeys, node_attrs
|
||||
)
|
||||
new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
|
||||
|
||||
conn_attrs = vmap(extract_conn_attrs)(conns)
|
||||
new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
new_conn_attrs = vmap(self.genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, conns_randkeys, conn_attrs
|
||||
)
|
||||
new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)
|
||||
|
||||
@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
|
||||
import datetime, time
|
||||
import numpy as np
|
||||
|
||||
from algorithm import BaseAlgorithm
|
||||
from problem import BaseProblem
|
||||
from problem.rl_env import RLEnv
|
||||
from problem.func_fit import FuncFit
|
||||
from tensorneat.algorithm import BaseAlgorithm
|
||||
from tensorneat.problem import BaseProblem
|
||||
from tensorneat.problem.rl_env import RLEnv
|
||||
from tensorneat.problem.func_fit import FuncFit
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class Pipeline(StatefulBaseClass):
|
||||
print("Fitness limit reached!")
|
||||
break
|
||||
|
||||
if self.algorithm.generation(state) >= self.generation_limit:
|
||||
if int(state.generation) >= self.generation_limit:
|
||||
print("Generation limit reached!")
|
||||
|
||||
if self.is_save:
|
||||
@@ -203,6 +203,8 @@ class Pipeline(StatefulBaseClass):
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
|
||||
generation = int(state.generation)
|
||||
|
||||
valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
|
||||
|
||||
@@ -223,8 +225,12 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
# save best
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
file_name = os.path.join(
|
||||
self.genome_dir, f"{generation}.npz"
|
||||
)
|
||||
with open(file_name, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
@@ -232,42 +238,18 @@ class Pipeline(StatefulBaseClass):
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# save best if save path is not None
|
||||
|
||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
pop = jax.device_get(pop)
|
||||
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
|
||||
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
|
||||
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
|
||||
|
||||
max_node_cnt, min_node_cnt, mean_node_cnt = (
|
||||
max(nodes_cnt),
|
||||
min(nodes_cnt),
|
||||
np.mean(nodes_cnt),
|
||||
)
|
||||
|
||||
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
|
||||
max(conns_cnt),
|
||||
min(conns_cnt),
|
||||
np.mean(conns_cnt),
|
||||
)
|
||||
# append log
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
|
||||
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
|
||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
# append log
|
||||
if self.is_save:
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
self.algorithm.show_details(state, fitnesses)
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
transformed = self.algorithm.transform(state, best)
|
||||
|
||||
Reference in New Issue
Block a user