modify NEAT package; successfully run xor example

This commit is contained in:
root
2024-07-11 10:10:16 +08:00
parent 52d5f046d3
commit 4a631f9464
14 changed files with 420 additions and 502 deletions

View File

View File

@@ -1,43 +1,38 @@
from pipeline import Pipeline from tensorneat.pipeline import Pipeline
from algorithm.neat import * from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, DefaultNodeGene, DefaultMutation
from problem.func_fit import XOR3d from tensorneat.problem.func_fit import XOR3d
from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg from tensorneat.common import Act, Agg
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( pop_size=10000,
genome=DenseInitialize( species_size=20,
num_inputs=3, compatibility_threshold=2,
num_outputs=1, survival_threshold=0.01,
max_nodes=50, genome=DefaultGenome(
max_conns=100, num_inputs=3,
node_gene=DefaultNodeGene( num_outputs=1,
activation_default=Act.tanh, init_hidden_layers=(),
# activation_options=(Act.tanh,), node_gene=DefaultNodeGene(
activation_options=ACT_ALL, activation_default=Act.tanh,
aggregation_default=Agg.sum, activation_options=Act.tanh,
# aggregation_options=(Agg.sum,), aggregation_default=Agg.sum,
aggregation_options=AGG_ALL, aggregation_options=Agg.sum,
), ),
output_transform=Act.standard_sigmoid, # the activation function for output node output_transform=Act.standard_sigmoid, # the activation function for output node
mutation=DefaultMutation( mutation=DefaultMutation(
node_add=0.1, node_add=0.1,
conn_add=0.1, conn_add=0.1,
node_delete=0, node_delete=0,
conn_delete=0, conn_delete=0,
),
), ),
pop_size=10000,
species_size=20,
compatibility_threshold=2,
survival_threshold=0.01, # magic
), ),
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=500,
fitness_target=-1e-3, fitness_target=-1e-8,
) )
# initialize state # initialize state
@@ -47,4 +42,3 @@ if __name__ == "__main__":
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
# show result # show result
pipeline.show(state, best) pipeline.show(state, best)
pipeline.save(state=state)

View File

@@ -6,7 +6,7 @@
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"> <rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work> <cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/> <dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:date>2024-07-10T16:50:19.947855</dc:date> <dc:date>2024-07-10T19:47:34.359228</dc:date>
<dc:format>image/svg+xml</dc:format> <dc:format>image/svg+xml</dc:format>
<dc:creator> <dc:creator>
<cc:Agent> <cc:Agent>
@@ -32,222 +32,222 @@ z
<g id="patch_2"> <g id="patch_2">
<path d="M 44.79098 308.403612 <path d="M 44.79098 308.403612
Q 87.590594 244.204191 129.770035 180.93503 Q 87.590594 244.204191 129.770035 180.93503
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 125.887134 183.153831 <path d="M 125.887134 183.153831
L 129.770035 180.93503 L 129.770035 180.93503
L 129.215335 185.372632 L 129.215335 185.372632
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_3"> <g id="patch_3">
<path d="M 46.916335 239.00779 <path d="M 46.916335 239.00779
Q 87.591722 208.50125 127.372682 178.665529 Q 87.591722 208.50125 127.372682 178.665529
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 122.972682 179.465529 <path d="M 122.972682 179.465529
L 127.372682 178.665529 L 127.372682 178.665529
L 125.372682 182.665529 L 125.372682 182.665529
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_4"> <g id="patch_4">
<path d="M 48.647998 172.8 <path d="M 48.647998 172.8
Q 87.590519 172.8 125.415005 172.8 Q 87.590519 172.8 125.415005 172.8
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 121.415005 170.8 <path d="M 121.415005 170.8
L 125.415005 172.8 L 125.415005 172.8
L 121.415005 174.8 L 121.415005 174.8
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_5"> <g id="patch_5">
<path d="M 46.916335 106.59221 <path d="M 46.916335 106.59221
Q 87.591722 137.09875 127.372682 166.934471 Q 87.591722 137.09875 127.372682 166.934471
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 125.372682 162.934471 <path d="M 125.372682 162.934471
L 127.372682 166.934471 L 127.372682 166.934471
L 122.972682 166.134471 L 122.972682 166.134471
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_6"> <g id="patch_6">
<path d="M 44.79098 37.196388 <path d="M 44.79098 37.196388
Q 87.590594 101.395809 129.770035 164.66497 Q 87.590594 101.395809 129.770035 164.66497
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 129.215335 160.227368 <path d="M 129.215335 160.227368
L 129.770035 164.66497 L 129.770035 164.66497
L 125.887134 162.446169 L 125.887134 162.446169
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_7"> <g id="patch_7">
<path d="M 143.30257 175.840943 <path d="M 143.30257 175.840943
Q 182.796502 190.651168 221.243586 205.068824 Q 182.796502 190.651168 221.243586 205.068824
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 218.200516 201.791672 <path d="M 218.200516 201.791672
L 221.243586 205.068824 L 221.243586 205.068824
L 216.796023 205.536989 L 216.796023 205.536989
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_8"> <g id="patch_8">
<path d="M 143.30257 169.759057 <path d="M 143.30257 169.759057
Q 182.796502 154.948832 221.243586 140.531176 Q 182.796502 154.948832 221.243586 140.531176
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 216.796023 140.063011 <path d="M 216.796023 140.063011
L 221.243586 140.531176 L 221.243586 140.531176
L 218.200516 143.808328 L 218.200516 143.808328
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_9"> <g id="patch_9">
<path d="M 238.509181 211.543422 <path d="M 238.509181 211.543422
Q 278.003113 226.353647 316.450198 240.771303 Q 278.003113 226.353647 316.450198 240.771303
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 313.407128 237.494151 <path d="M 313.407128 237.494151
L 316.450198 240.771303 L 316.450198 240.771303
L 312.002634 241.239468 L 312.002634 241.239468
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_10"> <g id="patch_10">
<path d="M 238.509181 205.461536 <path d="M 238.509181 205.461536
Q 278.003113 190.651312 316.450198 176.233655 Q 278.003113 190.651312 316.450198 176.233655
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 312.002634 175.765491 <path d="M 312.002634 175.765491
L 316.450198 176.233655 L 316.450198 176.233655
L 313.407128 179.510807 L 313.407128 179.510807
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_11"> <g id="patch_11">
<path d="M 236.155746 202.027265 <path d="M 236.155746 202.027265
Q 278.00531 154.946506 319.112092 108.701376 Q 278.00531 154.946506 319.112092 108.701376
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 314.959818 110.362286 <path d="M 314.959818 110.362286
L 319.112092 108.701376 L 319.112092 108.701376
L 317.949455 113.019741 L 317.949455 113.019741
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_12"> <g id="patch_12">
<path d="M 236.155746 143.572735 <path d="M 236.155746 143.572735
Q 278.00531 190.653494 319.112092 236.898624 Q 278.00531 190.653494 319.112092 236.898624
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 317.949455 232.580259 <path d="M 317.949455 232.580259
L 319.112092 236.898624 L 319.112092 236.898624
L 314.959818 235.237714 L 314.959818 235.237714
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_13"> <g id="patch_13">
<path d="M 238.509181 140.138464 <path d="M 238.509181 140.138464
Q 278.003113 154.948688 316.450198 169.366345 Q 278.003113 154.948688 316.450198 169.366345
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 313.407128 166.089193 <path d="M 313.407128 166.089193
L 316.450198 169.366345 L 316.450198 169.366345
L 312.002634 169.834509 L 312.002634 169.834509
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_14"> <g id="patch_14">
<path d="M 238.509181 134.056578 <path d="M 238.509181 134.056578
Q 278.003113 119.246353 316.450198 104.828697 Q 278.003113 119.246353 316.450198 104.828697
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 312.002634 104.360532 <path d="M 312.002634 104.360532
L 316.450198 104.828697 L 316.450198 104.828697
L 313.407128 108.105849 L 313.407128 108.105849
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_15"> <g id="patch_15">
<path d="M 334.267833 244.204959 <path d="M 334.267833 244.204959
Q 373.210353 244.204959 411.03484 244.204959 Q 373.210353 244.204959 411.03484 244.204959
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 242.204959 <path d="M 407.03484 242.204959
L 411.03484 244.204959 L 411.03484 244.204959
L 407.03484 246.204959 L 407.03484 246.204959
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_16"> <g id="patch_16">
<path d="M 332.53617 239.00779 <path d="M 332.53617 239.00779
Q 373.211557 208.50125 412.992517 178.665529 Q 373.211557 208.50125 412.992517 178.665529
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 408.592517 179.465529 <path d="M 408.592517 179.465529
L 412.992517 178.665529 L 412.992517 178.665529
L 410.992517 182.665529 L 410.992517 182.665529
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_17"> <g id="patch_17">
<path d="M 330.410815 236.998654 <path d="M 330.410815 236.998654
Q 373.210429 172.799232 415.38987 109.530071 Q 373.210429 172.799232 415.38987 109.530071
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 411.506968 111.748872 <path d="M 411.506968 111.748872
L 415.38987 109.530071 L 415.38987 109.530071
L 414.83517 113.967673 L 414.83517 113.967673
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_18"> <g id="patch_18">
<path d="M 332.53617 177.997169 <path d="M 332.53617 177.997169
Q 373.211557 208.503709 412.992517 238.339429 Q 373.211557 208.503709 412.992517 238.339429
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 410.992517 234.339429 <path d="M 410.992517 234.339429
L 412.992517 238.339429 L 412.992517 238.339429
L 408.592517 237.539429 L 408.592517 237.539429
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_19"> <g id="patch_19">
<path d="M 334.267833 172.8 <path d="M 334.267833 172.8
Q 373.210353 172.8 411.03484 172.8 Q 373.210353 172.8 411.03484 172.8
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 170.8 <path d="M 407.03484 170.8
L 411.03484 172.8 L 411.03484 172.8
L 407.03484 174.8 L 407.03484 174.8
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_20"> <g id="patch_20">
<path d="M 332.53617 167.602831 <path d="M 332.53617 167.602831
Q 373.211557 137.096291 412.992517 107.260571 Q 373.211557 137.096291 412.992517 107.260571
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 408.592517 108.060571 <path d="M 408.592517 108.060571
L 412.992517 107.260571 L 412.992517 107.260571
L 410.992517 111.260571 L 410.992517 111.260571
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_21"> <g id="patch_21">
<path d="M 330.410815 108.601346 <path d="M 330.410815 108.601346
Q 373.210429 172.800768 415.38987 236.069929 Q 373.210429 172.800768 415.38987 236.069929
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 414.83517 231.632327 <path d="M 414.83517 231.632327
L 415.38987 236.069929 L 415.38987 236.069929
L 411.506968 233.851128 L 411.506968 233.851128
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_22"> <g id="patch_22">
<path d="M 332.53617 106.59221 <path d="M 332.53617 106.59221
Q 373.211557 137.09875 412.992517 166.934471 Q 373.211557 137.09875 412.992517 166.934471
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 410.992517 162.934471 <path d="M 410.992517 162.934471
L 412.992517 166.934471 L 412.992517 166.934471
L 408.592517 166.134471 L 408.592517 166.134471
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="patch_23"> <g id="patch_23">
<path d="M 334.267833 101.395041 <path d="M 334.267833 101.395041
Q 373.210353 101.395041 411.03484 101.395041 Q 373.210353 101.395041 411.03484 101.395041
" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 99.395041 <path d="M 407.03484 99.395041
L 411.03484 101.395041 L 411.03484 101.395041
L 407.03484 103.395041 L 407.03484 103.395041
z z
" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> " clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g> </g>
<g id="PathCollection_1"> <g id="PathCollection_1">
<path d="M 39.986777 324.270171 <path d="M 39.986777 324.270171
@@ -260,7 +260,7 @@ C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642 C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171 C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 252.865213 <path d="M 39.986777 252.865213
C 42.283503 252.865213 44.486471 251.952714 46.110501 250.328683 C 42.283503 252.865213 44.486471 251.952714 46.110501 250.328683
C 47.734532 248.704652 48.647031 246.501685 48.647031 244.204959 C 47.734532 248.704652 48.647031 246.501685 48.647031 244.204959
@@ -271,7 +271,7 @@ C 32.239022 239.705265 31.326523 241.908232 31.326523 244.204959
C 31.326523 246.501685 32.239022 248.704652 33.863053 250.328683 C 31.326523 246.501685 32.239022 248.704652 33.863053 250.328683
C 35.487083 251.952714 37.690051 252.865213 39.986777 252.865213 C 35.487083 251.952714 37.690051 252.865213 39.986777 252.865213
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 181.460254 <path d="M 39.986777 181.460254
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724 C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8 C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
@@ -282,7 +282,7 @@ C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724 C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254 C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 110.055295 <path d="M 39.986777 110.055295
C 42.283503 110.055295 44.486471 109.142796 46.110501 107.518766 C 42.283503 110.055295 44.486471 109.142796 46.110501 107.518766
C 47.734532 105.894735 48.647031 103.691768 48.647031 101.395041 C 47.734532 105.894735 48.647031 103.691768 48.647031 101.395041
@@ -293,7 +293,7 @@ C 32.239022 96.895348 31.326523 99.098315 31.326523 101.395041
C 31.326523 103.691768 32.239022 105.894735 33.863053 107.518766 C 31.326523 103.691768 32.239022 105.894735 33.863053 107.518766
C 35.487083 109.142796 37.690051 110.055295 39.986777 110.055295 C 35.487083 109.142796 37.690051 110.055295 39.986777 110.055295
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 38.650337 <path d="M 39.986777 38.650337
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807 C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083 C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
@@ -304,7 +304,7 @@ C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807 C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337 C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 135.193388 181.460254 <path d="M 135.193388 181.460254
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724 C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8 C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
@@ -315,7 +315,7 @@ C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724 C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254 C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 230.4 217.162733 <path d="M 230.4 217.162733
C 232.696726 217.162733 234.899694 216.250234 236.523724 214.626204 C 232.696726 217.162733 234.899694 216.250234 236.523724 214.626204
C 238.147755 213.002173 239.060254 210.799206 239.060254 208.502479 C 238.147755 213.002173 239.060254 210.799206 239.060254 208.502479
@@ -326,7 +326,7 @@ C 222.652245 204.002786 221.739746 206.205753 221.739746 208.502479
C 221.739746 210.799206 222.652245 213.002173 224.276276 214.626204 C 221.739746 210.799206 222.652245 213.002173 224.276276 214.626204
C 225.900306 216.250234 228.103274 217.162733 230.4 217.162733 C 225.900306 216.250234 228.103274 217.162733 230.4 217.162733
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 230.4 145.757775 <path d="M 230.4 145.757775
C 232.696726 145.757775 234.899694 144.845276 236.523724 143.221245 C 232.696726 145.757775 234.899694 144.845276 236.523724 143.221245
C 238.147755 141.597214 239.060254 139.394247 239.060254 137.097521 C 238.147755 141.597214 239.060254 139.394247 239.060254 137.097521
@@ -337,7 +337,7 @@ C 222.652245 132.597827 221.739746 134.800794 221.739746 137.097521
C 221.739746 139.394247 222.652245 141.597214 224.276276 143.221245 C 221.739746 139.394247 222.652245 141.597214 224.276276 143.221245
C 225.900306 144.845276 228.103274 145.757775 230.4 145.757775 C 225.900306 144.845276 228.103274 145.757775 230.4 145.757775
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 252.865213 <path d="M 325.606612 252.865213
C 327.903338 252.865213 330.106305 251.952714 331.730336 250.328683 C 327.903338 252.865213 330.106305 251.952714 331.730336 250.328683
C 333.354367 248.704652 334.266866 246.501685 334.266866 244.204959 C 333.354367 248.704652 334.266866 246.501685 334.266866 244.204959
@@ -348,7 +348,7 @@ C 317.858857 239.705265 316.946358 241.908232 316.946358 244.204959
C 316.946358 246.501685 317.858857 248.704652 319.482887 250.328683 C 316.946358 246.501685 317.858857 248.704652 319.482887 250.328683
C 321.106918 251.952714 323.309885 252.865213 325.606612 252.865213 C 321.106918 251.952714 323.309885 252.865213 325.606612 252.865213
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 181.460254 <path d="M 325.606612 181.460254
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724 C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8 C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
@@ -359,7 +359,7 @@ C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724 C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254 C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 110.055295 <path d="M 325.606612 110.055295
C 327.903338 110.055295 330.106305 109.142796 331.730336 107.518766 C 327.903338 110.055295 330.106305 109.142796 331.730336 107.518766
C 333.354367 105.894735 334.266866 103.691768 334.266866 101.395041 C 333.354367 105.894735 334.266866 103.691768 334.266866 101.395041
@@ -370,7 +370,7 @@ C 317.858857 96.895348 316.946358 99.098315 316.946358 101.395041
C 316.946358 103.691768 317.858857 105.894735 319.482887 107.518766 C 316.946358 103.691768 317.858857 105.894735 319.482887 107.518766
C 321.106918 109.142796 323.309885 110.055295 325.606612 110.055295 C 321.106918 109.142796 323.309885 110.055295 325.606612 110.055295
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 252.865213 <path d="M 420.813223 252.865213
C 423.109949 252.865213 425.312917 251.952714 426.936947 250.328683 C 423.109949 252.865213 425.312917 251.952714 426.936947 250.328683
C 428.560978 248.704652 429.473477 246.501685 429.473477 244.204959 C 428.560978 248.704652 429.473477 246.501685 429.473477 244.204959
@@ -381,7 +381,7 @@ C 413.065468 239.705265 412.152969 241.908232 412.152969 244.204959
C 412.152969 246.501685 413.065468 248.704652 414.689499 250.328683 C 412.152969 246.501685 413.065468 248.704652 414.689499 250.328683
C 416.313529 251.952714 418.516497 252.865213 420.813223 252.865213 C 416.313529 251.952714 418.516497 252.865213 420.813223 252.865213
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 181.460254 <path d="M 420.813223 181.460254
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724 C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8 C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
@@ -392,7 +392,7 @@ C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724 C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254 C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 110.055295 <path d="M 420.813223 110.055295
C 423.109949 110.055295 425.312917 109.142796 426.936947 107.518766 C 423.109949 110.055295 425.312917 109.142796 426.936947 107.518766
C 428.560978 105.894735 429.473477 103.691768 429.473477 101.395041 C 428.560978 105.894735 429.473477 103.691768 429.473477 101.395041
@@ -403,12 +403,12 @@ C 413.065468 96.895348 412.152969 99.098315 412.152969 101.395041
C 412.152969 103.691768 413.065468 105.894735 414.689499 107.518766 C 412.152969 103.691768 413.065468 105.894735 414.689499 107.518766
C 416.313529 109.142796 418.516497 110.055295 420.813223 110.055295 C 416.313529 109.142796 418.516497 110.055295 420.813223 110.055295
z z
" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> " clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
</g> </g>
</g> </g>
</g> </g>
<defs> <defs>
<clipPath id="p8fe09283f8"> <clipPath id="p572566e0dc">
<rect x="0" y="0" width="460.8" height="345.6"/> <rect x="0" y="0" width="460.8" height="345.6"/>
</clipPath> </clipPath>
</defs> </defs>

Before

Width:  |  Height:  |  Size: 18 KiB

After

Width:  |  Height:  |  Size: 18 KiB

View File

@@ -14,13 +14,11 @@ class BaseAlgorithm(StatefulBaseClass):
"""transform the genome into a neural network""" """transform the genome into a neural network"""
raise NotImplementedError raise NotImplementedError
def restore(self, state, transformed):
raise NotImplementedError
def forward(self, state, transformed, inputs): def forward(self, state, transformed, inputs):
raise NotImplementedError raise NotImplementedError
def update_by_batch(self, state, batch_input, transformed): def show_details(self, state: State, fitness):
"""Visualize the running details of the algorithm"""
raise NotImplementedError raise NotImplementedError
@property @property
@@ -30,15 +28,3 @@ class BaseAlgorithm(StatefulBaseClass):
@property @property
def num_outputs(self): def num_outputs(self):
raise NotImplementedError raise NotImplementedError
@property
def pop_size(self):
raise NotImplementedError
def member_count(self, state: State):
# to analysis the species
raise NotImplementedError
def generation(self, state: State):
# to analysis the algorithm
raise NotImplementedError

View File

@@ -1,40 +1,93 @@
from tensorneat.common import State import jax
from jax import vmap, numpy as jnp
import numpy as np
from .species import SpeciesController
from .. import BaseAlgorithm from .. import BaseAlgorithm
from .species import * from tensorneat.common import State
from tensorneat.genome import BaseGenome
class NEAT(BaseAlgorithm): class NEAT(BaseAlgorithm):
def __init__( def __init__(
self, self,
species: BaseSpecies, genome: BaseGenome,
pop_size: int,
species_size: int = 10,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.0,
species_fitness_func: callable = jnp.max,
): ):
self.species = species self.genome = genome
self.genome = species.genome self.pop_size = pop_size
self.species_controller = SpeciesController(
pop_size,
species_size,
max_stagnation,
species_elitism,
spawn_number_change_rate,
genome_elitism,
survival_threshold,
min_species_size,
compatibility_threshold,
species_fitness_func,
)
def setup(self, state=State()): def setup(self, state=State()):
state = self.species.setup(state) # setup state
state = self.genome.setup(state)
k1, randkey = jax.random.split(state.randkey, 2)
# initialize the population
initialize_keys = jax.random.split(k1, self.pop_size)
pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
state, initialize_keys
)
state = state.register(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
generation=jnp.float32(0),
)
# initialize species state
state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])
return state.update(randkey=randkey)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
state = state.update(generation=state.generation + 1)
# tell fitness to species controller
state, winner, loser, elite_mask = self.species_controller.update_species(
state,
fitness,
)
# create next population
state = self._create_next_generation(state, winner, loser, elite_mask)
# speciate the next population
state = self.species_controller.speciate(state, self.genome.execute_distance)
return state return state
def ask(self, state: State):
return self.species.ask(state)
def tell(self, state: State, fitness):
return self.species.tell(state, fitness)
def transform(self, state, individual): def transform(self, state, individual):
"""transform the genome into a neural network"""
nodes, conns = individual nodes, conns = individual
return self.genome.transform(state, nodes, conns) return self.genome.transform(state, nodes, conns)
def restore(self, state, transformed):
return self.genome.restore(state, transformed)
def forward(self, state, transformed, inputs): def forward(self, state, transformed, inputs):
return self.genome.forward(state, transformed, inputs) return self.genome.forward(state, transformed, inputs)
def update_by_batch(self, state, batch_input, transformed):
return self.genome.update_by_batch(state, batch_input, transformed)
@property @property
def num_inputs(self): def num_inputs(self):
return self.genome.num_inputs return self.genome.num_inputs
@@ -43,13 +96,70 @@ class NEAT(BaseAlgorithm):
def num_outputs(self): def num_outputs(self):
return self.genome.num_outputs return self.genome.num_outputs
@property def _create_next_generation(self, state, winner, loser, elite_mask):
def pop_size(self):
return self.species.pop_size
def member_count(self, state: State): # find next node key for mutation
return state.member_count all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
)
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
def generation(self, state: State): # prepare random keys
# to analysis the algorithm k1, k2, randkey = jax.random.split(state.randkey, 3)
return state.generation crossover_randkeys = jax.random.split(k1, self.pop_size)
mutate_randkeys = jax.random.split(k2, self.pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
)
def show_details(self, state, fitness):
member_count = jax.device_get(state.species.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
max_node_cnt, min_node_cnt, mean_node_cnt = (
max(nodes_cnt),
min(nodes_cnt),
np.mean(nodes_cnt),
)
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
max(conns_cnt),
min(conns_cnt),
np.mean(conns_cnt),
)
print(
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
)

View File

@@ -1,54 +1,35 @@
import jax, jax.numpy as jnp from typing import Callable
import jax
from jax import vmap, numpy as jnp
import numpy as np
from .base import BaseSpecies
from tensorneat.common import ( from tensorneat.common import (
State, State,
StatefulBaseClass,
rank_elements, rank_elements,
argmin_with_mask, argmin_with_mask,
fetch_first, fetch_first,
) )
from tensorneat.genome.utils import (
extract_conn_attrs,
extract_node_attrs,
)
from tensorneat.genome import BaseGenome
""" class SpeciesController(StatefulBaseClass):
Core procedures of NEAT algorithm, contains the following steps:
1. Update the fitness of each species;
2. Decide which species will be stagnation;
3. Decide the number of members of each species in the next generation;
4. Choice the crossover pair for each species;
5. Divided the whole new population into different species;
This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python.
The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases.
"""
class DefaultSpecies(BaseSpecies):
def __init__( def __init__(
self, self,
genome: BaseGenome,
pop_size, pop_size,
species_size, species_size,
compatibility_disjoint: float = 1.0, max_stagnation,
compatibility_weight: float = 0.4, species_elitism,
max_stagnation: int = 15, spawn_number_change_rate,
species_elitism: int = 2, genome_elitism,
spawn_number_change_rate: float = 0.5, survival_threshold,
genome_elitism: int = 2, min_species_size,
survival_threshold: float = 0.2, compatibility_threshold,
min_species_size: int = 1, species_fitness_func,
compatibility_threshold: float = 3.0,
): ):
self.genome = genome
self.pop_size = pop_size self.pop_size = pop_size
self.species_size = species_size self.species_size = species_size
self.species_arange = np.arange(self.species_size)
self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight
self.max_stagnation = max_stagnation self.max_stagnation = max_stagnation
self.species_elitism = species_elitism self.species_elitism = species_elitism
self.spawn_number_change_rate = spawn_number_change_rate self.spawn_number_change_rate = spawn_number_change_rate
@@ -56,42 +37,33 @@ class DefaultSpecies(BaseSpecies):
self.survival_threshold = survival_threshold self.survival_threshold = survival_threshold
self.min_species_size = min_species_size self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold self.compatibility_threshold = compatibility_threshold
self.species_fitness_func = species_fitness_func
self.species_arange = jnp.arange(self.species_size) def setup(self, state, first_nodes, first_conns):
# the unique index (primary key) for each species
species_keys = jnp.full((self.species_size,), jnp.nan)
def setup(self, state=State()): # the best fitness of each species
state = self.genome.setup(state) best_fitness = jnp.full((self.species_size,), jnp.nan)
k1, randkey = jax.random.split(state.randkey, 2)
# initialize the population # the last 1 that the species improved
initialize_keys = jax.random.split(randkey, self.pop_size) last_improved = jnp.full((self.species_size,), jnp.nan)
pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))(
state, initialize_keys
)
species_keys = jnp.full( # the number of members of each species
(self.species_size,), jnp.nan member_count = jnp.full((self.species_size,), jnp.nan)
) # the unique index (primary key) for each species
best_fitness = jnp.full( # the species index of each individual
(self.species_size,), jnp.nan idx2species = jnp.zeros(self.pop_size)
) # the best fitness of each species
last_improved = jnp.full(
(self.species_size,), jnp.nan
) # the last 1 that the species improved
member_count = jnp.full(
(self.species_size,), jnp.nan
) # the number of members of each species
idx2species = jnp.zeros(self.pop_size) # the species index of each individual
# nodes for each center genome of each species # nodes for each center genome of each species
center_nodes = jnp.full( center_nodes = jnp.full(
(self.species_size, self.genome.max_nodes, self.genome.node_gene.length), (self.species_size, *first_nodes.shape),
jnp.nan, jnp.nan,
) )
# connections for each center genome of each species # connections for each center genome of each species
center_conns = jnp.full( center_conns = jnp.full(
(self.species_size, self.genome.max_conns, self.genome.conn_gene.length), (self.species_size, *first_conns.shape),
jnp.nan, jnp.nan,
) )
@@ -99,16 +71,10 @@ class DefaultSpecies(BaseSpecies):
best_fitness = best_fitness.at[0].set(-jnp.inf) best_fitness = best_fitness.at[0].set(-jnp.inf)
last_improved = last_improved.at[0].set(0) last_improved = last_improved.at[0].set(0)
member_count = member_count.at[0].set(self.pop_size) member_count = member_count.at[0].set(self.pop_size)
center_nodes = center_nodes.at[0].set(pop_nodes[0]) center_nodes = center_nodes.at[0].set(first_nodes)
center_conns = center_conns.at[0].set(pop_conns[0]) center_conns = center_conns.at[0].set(first_conns)
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) species_state = State(
state = state.update(randkey=randkey)
return state.register(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_keys=species_keys, species_keys=species_keys,
best_fitness=best_fitness, best_fitness=best_fitness,
last_improved=last_improved, last_improved=last_improved,
@@ -117,53 +83,50 @@ class DefaultSpecies(BaseSpecies):
center_nodes=center_nodes, center_nodes=center_nodes,
center_conns=center_conns, center_conns=center_conns,
next_species_key=jnp.float32(1), # 0 is reserved for the first species next_species_key=jnp.float32(1), # 0 is reserved for the first species
generation=jnp.float32(0),
) )
def ask(self, state): return state.register(species=species_state)
return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.speciate(state)
return state
def update_species(self, state, fitness): def update_species(self, state, fitness):
species_state = state.species
# update the fitness of each species # update the fitness of each species
state, species_fitness = self.update_species_fitness(state, fitness) species_fitness = self._update_species_fitness(species_state, fitness)
# stagnation species # stagnation species
state, species_fitness = self.stagnation(state, species_fitness) species_state, species_fitness = self._stagnation(
species_state, species_fitness, state.generation
)
# sort species_info by their fitness. (also push nan to the end) # sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1] sort_indices = jnp.argsort(species_fitness)[::-1] # fitness from high to low
state = state.update( species_state = species_state.update(
species_keys=state.species_keys[sort_indices], species_keys=species_state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices], best_fitness=species_state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices], last_improved=species_state.last_improved[sort_indices],
member_count=state.member_count[sort_indices], member_count=species_state.member_count[sort_indices],
center_nodes=state.center_nodes[sort_indices], center_nodes=species_state.center_nodes[sort_indices],
center_conns=state.center_conns[sort_indices], center_conns=species_state.center_conns[sort_indices],
) )
# decide the number of members of each species by their fitness # decide the number of members of each species by their fitness
state, spawn_number = self.cal_spawn_numbers(state) spawn_number = self._cal_spawn_numbers(species_state)
k1, k2 = jax.random.split(state.randkey) k1, k2 = jax.random.split(state.randkey)
# crossover info # crossover info
state, winner, loser, elite_mask = self.create_crossover_pair( winner, loser, elite_mask = self._create_crossover_pair(
state, spawn_number, fitness species_state, k1, spawn_number, fitness
) )
return state.update(randkey=k2), winner, loser, elite_mask return (
state.update(randkey=k2, species=species_state),
winner,
loser,
elite_mask,
)
def update_species_fitness(self, state, fitness): def _update_species_fitness(self, species_state, fitness):
""" """
obtain the fitness of the species by the fitness of each individual. obtain the fitness of the species by the fitness of each individual.
use max criterion. use max criterion.
@@ -171,14 +134,16 @@ class DefaultSpecies(BaseSpecies):
def aux_func(idx): def aux_func(idx):
s_fitness = jnp.where( s_fitness = jnp.where(
state.idx2species == state.species_keys[idx], fitness, -jnp.inf species_state.idx2species == species_state.species_keys[idx],
fitness,
-jnp.inf,
) )
val = jnp.max(s_fitness) val = self.species_fitness_func(s_fitness)
return val return val
return state, jax.vmap(aux_func)(self.species_arange) return vmap(aux_func)(self.species_arange)
def stagnation(self, state, species_fitness): def _stagnation(self, species_state, species_fitness, generation):
""" """
stagnation species. stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
@@ -187,28 +152,36 @@ class DefaultSpecies(BaseSpecies):
def check_stagnation(idx): def check_stagnation(idx):
# determine whether the species stagnation # determine whether the species stagnation
st = (
species_fitness[idx] <= state.best_fitness[idx] # not better than the best fitness of the species
) & ( # not better than the best fitness of the species # for a long time
state.generation - state.last_improved[idx] > self.max_stagnation st = (species_fitness[idx] <= species_state.best_fitness[idx]) & (
) # for a long time generation - species_state.last_improved[idx] > self.max_stagnation
)
# update last_improved and best_fitness # update last_improved and best_fitness
# whether better than the best fitness of the species
li, bf = jax.lax.cond( li, bf = jax.lax.cond(
species_fitness[idx] > state.best_fitness[idx], species_fitness[idx] > species_state.best_fitness[idx],
lambda: (state.generation, species_fitness[idx]), # update lambda: (generation, species_fitness[idx]), # update
lambda: ( lambda: (
state.last_improved[idx], species_state.last_improved[idx],
state.best_fitness[idx], species_state.best_fitness[idx],
), # not update ), # not update
) )
return st, bf, li return st, bf, li
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)( spe_st, best_fitness, last_improved = vmap(check_stagnation)(
self.species_arange self.species_arange
) )
# update species state
species_state = species_state.update(
best_fitness=best_fitness,
last_improved=last_improved,
)
# elite species will not be stagnation # elite species will not be stagnation
species_rank = rank_elements(species_fitness) species_rank = rank_elements(species_fitness)
spe_st = jnp.where( spe_st = jnp.where(
@@ -224,18 +197,18 @@ class DefaultSpecies(BaseSpecies):
jnp.nan, # best_fitness jnp.nan, # best_fitness
jnp.nan, # last_improved jnp.nan, # last_improved
jnp.nan, # member_count jnp.nan, # member_count
jnp.full_like(species_state.center_nodes[idx], jnp.nan),
jnp.full_like(species_state.center_conns[idx], jnp.nan),
-jnp.inf, # species_fitness -jnp.inf, # species_fitness
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
), # stagnation species ), # stagnation species
lambda: ( lambda: (
state.species_keys[idx], species_state.species_keys[idx],
best_fitness[idx], species_state.best_fitness[idx],
last_improved[idx], species_state.last_improved[idx],
state.member_count[idx], species_state.member_count[idx],
species_state.center_nodes[idx],
species_state.center_conns[idx],
species_fitness[idx], species_fitness[idx],
state.center_nodes[idx],
state.center_conns[idx],
), # not stagnation species ), # not stagnation species
) )
@@ -244,13 +217,13 @@ class DefaultSpecies(BaseSpecies):
best_fitness, best_fitness,
last_improved, last_improved,
member_count, member_count,
species_fitness,
center_nodes, center_nodes,
center_conns, center_conns,
) = jax.vmap(update_func)(self.species_arange) species_fitness,
) = vmap(update_func)(self.species_arange)
return ( return (
state.update( species_state.update(
species_keys=species_keys, species_keys=species_keys,
best_fitness=best_fitness, best_fitness=best_fitness,
last_improved=last_improved, last_improved=last_improved,
@@ -261,7 +234,7 @@ class DefaultSpecies(BaseSpecies):
species_fitness, species_fitness,
) )
def cal_spawn_numbers(self, state): def _cal_spawn_numbers(self, species_state):
""" """
decide the number of members of each species by their fitness rank. decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members the species with higher fitness will have more members
@@ -269,7 +242,7 @@ class DefaultSpecies(BaseSpecies):
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
""" """
species_keys = state.species_keys species_keys = species_state.species_keys
is_species_valid = ~jnp.isnan(species_keys) is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid) valid_species_num = jnp.sum(is_species_valid)
@@ -288,7 +261,7 @@ class DefaultSpecies(BaseSpecies):
) # calculate member ) # calculate member
# Avoid too much variation of numbers for a species # Avoid too much variation of numbers for a species
previous_size = state.member_count previous_size = species_state.member_count
spawn_number = ( spawn_number = (
previous_size previous_size
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate + (target_spawn_number - previous_size) * self.spawn_number_change_rate
@@ -301,14 +274,17 @@ class DefaultSpecies(BaseSpecies):
# add error to the first species to control the sum of spawn_number # add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error) spawn_number = spawn_number.at[0].add(error)
return state, spawn_number return spawn_number
def create_crossover_pair(self, state, spawn_number, fitness): def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
s_idx = self.species_arange s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size) p_idx = jnp.arange(self.pop_size)
def aux_func(key, idx): def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx] # choose parents from the in the same species
# key -> randkey, idx -> the idx of current species
members = species_state.idx2species == species_state.species_keys[idx]
members_num = jnp.sum(members) members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf) members_fitness = jnp.where(members, fitness, -jnp.inf)
@@ -333,11 +309,16 @@ class DefaultSpecies(BaseSpecies):
elite = jnp.where(p_idx < self.genome_elitism, True, False) elite = jnp.where(p_idx < self.genome_elitism, True, False)
return fa, ma, elite return fa, ma, elite
randkey_, randkey = jax.random.split(state.randkey) # choose parents to crossover in each species
fas, mas, elites = jax.vmap(aux_func)( # fas, mas, elites: (self.species_size, self.pop_size)
jax.random.split(randkey_, self.species_size), s_idx # fas -> father indices, mas -> mother indices, elites -> whether elite or not
fas, mas, elites = vmap(aux_func)(
jax.random.split(randkey, self.species_size), s_idx
) )
# merge choosen parents from each species into one array
# winner, loser, elite_mask: (self.pop_size)
# winner -> winner indices, loser -> loser indices, elite_mask -> whether elite or not
spawn_number_cum = jnp.cumsum(spawn_number) spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx): def aux_func(idx):
@@ -351,18 +332,18 @@ class DefaultSpecies(BaseSpecies):
elites[loc, idx_in_species], elites[loc, idx_in_species],
) )
part1, part2, elite_mask = jax.vmap(aux_func)(p_idx) part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2] is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2) winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1) loser = jnp.where(is_part1_win, part2, part1)
return state.update(randkey=randkey), winner, loser, elite_mask return winner, loser, elite_mask
def speciate(self, state): def speciate(self, state, genome_distance_func: Callable):
# prepare distance functions # prepare distance functions
o2p_distance_func = jax.vmap( o2p_distance_func = vmap(
self.distance, in_axes=(None, None, None, 0, 0) genome_distance_func, in_axes=(None, None, None, 0, 0)
) # one to population ) # one to population
# idx to specie key # idx to specie key
@@ -379,7 +360,7 @@ class DefaultSpecies(BaseSpecies):
i, i2s, cns, ccs, o2c = carry i, i2s, cns, ccs, o2c = carry
return (i < self.species_size) & ( return (i < self.species_size) & (
~jnp.isnan(state.species_keys[i]) ~jnp.isnan(state.species.species_keys[i])
) # current species is existing ) # current species is existing
def body_func(carry): def body_func(carry):
@@ -392,7 +373,7 @@ class DefaultSpecies(BaseSpecies):
# find the closest one # find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_keys[i]) i2s = i2s.at[closest_idx].set(state.species.species_keys[i])
cns = cns.at[i].set(state.pop_nodes[closest_idx]) cns = cns.at[i].set(state.pop_nodes[closest_idx])
ccs = ccs.at[i].set(state.pop_conns[closest_idx]) ccs = ccs.at[i].set(state.pop_conns[closest_idx])
@@ -404,13 +385,21 @@ class DefaultSpecies(BaseSpecies):
_, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop( _, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
cond_func, cond_func,
body_func, body_func,
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances), (
0,
idx2species,
state.species.center_nodes,
state.species.center_conns,
o2c_distances,
),
) )
state = state.update( state = state.update(
idx2species=idx2species, species=state.species.update(
center_nodes=center_nodes, idx2species=idx2species,
center_conns=center_conns, center_nodes=center_nodes,
center_conns=center_conns,
),
) )
# part 2: assign members to each species # part 2: assign members to each species
@@ -500,12 +489,12 @@ class DefaultSpecies(BaseSpecies):
body_func, body_func,
( (
0, 0,
state.idx2species, state.species.idx2species,
center_nodes, center_nodes,
center_conns, center_conns,
state.species_keys, state.species.species_keys,
o2c_distances, o2c_distances,
state.next_species_key, state.species.next_species_key,
), ),
) )
@@ -514,10 +503,10 @@ class DefaultSpecies(BaseSpecies):
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation # complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness) new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness) best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
last_improved = jnp.where( last_improved = jnp.where(
new_created_mask, state.generation, state.last_improved new_created_mask, state.generation, state.species.last_improved
) )
# update members count # update members count
@@ -530,9 +519,9 @@ class DefaultSpecies(BaseSpecies):
), # count members ), # count members
) )
member_count = jax.vmap(count_members)(self.species_arange) member_count = vmap(count_members)(self.species_arange)
return state.update( species_state = state.species.update(
species_keys=species_keys, species_keys=species_keys,
best_fitness=best_fitness, best_fitness=best_fitness,
last_improved=last_improved, last_improved=last_improved,
@@ -543,135 +532,6 @@ class DefaultSpecies(BaseSpecies):
next_species_key=next_species_key, next_species_key=next_species_key,
) )
def distance(self, state, nodes1, conns1, nodes2, conns2):
"""
The distance between two genomes
"""
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
state, conns1, conns2
)
return d
def node_distance(self, state, nodes1, nodes2):
"""
The distance of the nodes part for two genomes
"""
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate(
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
fr_attrs = jax.vmap(extract_node_attrs)(fr)
sr_attrs = jax.vmap(extract_node_attrs)(sr)
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
state, fr_attrs, sr_attrs
) # homologous node distance
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = (
non_homologous_cnt * self.compatibility_disjoint
+ homologous_distance * self.compatibility_weight
)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def conn_distance(self, state, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((conns1, conns2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate(
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
fr_attrs = jax.vmap(extract_conn_attrs)(fr)
sr_attrs = jax.vmap(extract_conn_attrs)(sr)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
state, fr_attrs, sr_attrs
) # homologous connection distance
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = (
non_homologous_cnt * self.compatibility_disjoint
+ homologous_distance * self.compatibility_weight
)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def create_next_generation(self, state, winner, loser, elite_mask):
# find next node key
all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
)
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
# prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size)
mutate_randkeys = jax.random.split(k2, self.pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
return state.update( return state.update(
randkey=randkey, species=species_state,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
) )

View File

@@ -1,2 +0,0 @@
from .base import BaseSpecies
from .default import DefaultSpecies

View File

@@ -1,20 +0,0 @@
from tensorneat.common import State, StatefulBaseClass
from tensorneat.genome import BaseGenome
class BaseSpecies(StatefulBaseClass):
genome: BaseGenome
pop_size: int
species_size: int
def ask(self, state: State):
raise NotImplementedError
def tell(self, state: State, fitness):
raise NotImplementedError
def update_species(self, state, fitness):
raise NotImplementedError
def speciate(self, state):
raise NotImplementedError

View File

@@ -1,3 +1,5 @@
from .gene import *
from .operations import *
from .base import BaseGenome from .base import BaseGenome
from .default import DefaultGenome from .default import DefaultGenome
from .recurrent import RecurrentGenome from .recurrent import RecurrentGenome

View File

@@ -1,4 +1,4 @@
from typing import Tuple from typing import Tuple, Union, Sequence, Callable
import numpy as np import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
@@ -34,14 +34,20 @@ class DefaultNodeGene(BaseNodeGene):
response_mutate_power: float = 0.5, response_mutate_power: float = 0.5,
response_mutate_rate: float = 0.7, response_mutate_rate: float = 0.7,
response_replace_rate: float = 0.1, response_replace_rate: float = 0.1,
aggregation_default: callable = Agg.sum, aggregation_default: Callable = Agg.sum,
aggregation_options: Tuple = (Agg.sum,), aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
aggregation_replace_rate: float = 0.1, aggregation_replace_rate: float = 0.1,
activation_default: callable = Act.sigmoid, activation_default: Callable = Act.sigmoid,
activation_options: Tuple = (Act.sigmoid,), activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
activation_replace_rate: float = 0.1, activation_replace_rate: float = 0.1,
): ):
super().__init__() super().__init__()
if isinstance(aggregation_options, Callable):
aggregation_options = [aggregation_options]
if isinstance(activation_options, Callable):
activation_options = [activation_options]
self.bias_init_mean = bias_init_mean self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power self.bias_mutate_power = bias_mutate_power

View File

@@ -13,7 +13,7 @@ class DefaultDistance(BaseDistance):
self.compatibility_disjoint = compatibility_disjoint self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight self.compatibility_weight = compatibility_weight
def __call__(self, state, nodes1, nodes2, conns1, conns2): def __call__(self, state, nodes1, conns1, nodes2, conns2):
""" """
The distance between two genomes The distance between two genomes
""" """

View File

@@ -8,5 +8,5 @@ class BaseMutation(StatefulBaseClass):
self.genome = genome self.genome = genome
return state return state
def __call__(self, state, randkey, genome, nodes, conns, new_node_key): def __call__(self, state, randkey, nodes, conns, new_node_key):
raise NotImplementedError raise NotImplementedError

View File

@@ -33,17 +33,17 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add self.node_add = node_add
self.node_delete = node_delete self.node_delete = node_delete
def __call__(self, state, randkey, genome, nodes, conns, new_node_key): def __call__(self, state, randkey, nodes, conns, new_node_key):
k1, k2 = jax.random.split(randkey) k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure( nodes, conns = self.mutate_structure(
state, k1, genome, nodes, conns, new_node_key state, k1, nodes, conns, new_node_key
) )
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns) nodes, conns = self.mutate_values(state, k2, nodes, conns)
return nodes, conns return nodes, conns
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key): def mutate_structure(self, state, randkey, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_): def mutate_add_node(key_, nodes_, conns_):
""" """
add a node while do not influence the output of the network add a node while do not influence the output of the network
@@ -62,7 +62,7 @@ class DefaultMutation(BaseMutation):
# add a new node with identity attrs # add a new node with identity attrs
new_nodes = add_node( new_nodes = add_node(
nodes_, new_node_key, genome.node_gene.new_identity_attrs(state) nodes_, new_node_key, self.genome.node_gene.new_identity_attrs(state)
) )
# add two new connections # add two new connections
@@ -71,7 +71,7 @@ class DefaultMutation(BaseMutation):
new_conns, new_conns,
i_key, i_key,
new_node_key, new_node_key,
genome.conn_gene.new_identity_attrs(state), self.genome.conn_gene.new_identity_attrs(state),
) )
# second is with the origin attrs # second is with the origin attrs
new_conns = add_conn( new_conns = add_conn(
@@ -97,8 +97,8 @@ class DefaultMutation(BaseMutation):
key, idx = self.choose_node_key( key, idx = self.choose_node_key(
key_, key_,
nodes_, nodes_,
genome.input_idx, self.genome.input_idx,
genome.output_idx, self.genome.output_idx,
allow_input_keys=False, allow_input_keys=False,
allow_output_keys=False, allow_output_keys=False,
) )
@@ -136,8 +136,8 @@ class DefaultMutation(BaseMutation):
i_key, from_idx = self.choose_node_key( i_key, from_idx = self.choose_node_key(
k1_, k1_,
nodes_, nodes_,
genome.input_idx, self.genome.input_idx,
genome.output_idx, self.genome.output_idx,
allow_input_keys=True, allow_input_keys=True,
allow_output_keys=True, allow_output_keys=True,
) )
@@ -146,8 +146,8 @@ class DefaultMutation(BaseMutation):
o_key, to_idx = self.choose_node_key( o_key, to_idx = self.choose_node_key(
k2_, k2_,
nodes_, nodes_,
genome.input_idx, self.genome.input_idx,
genome.output_idx, self.genome.output_idx,
allow_input_keys=False, allow_input_keys=False,
allow_output_keys=True, allow_output_keys=True,
) )
@@ -161,10 +161,10 @@ class DefaultMutation(BaseMutation):
def successful(): def successful():
# add a connection with zero attrs # add a connection with zero attrs
return nodes_, add_conn( return nodes_, add_conn(
conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state) conns_, i_key, o_key, self.genome.conn_gene.new_zero_attrs(state)
) )
if genome.network_type == "feedforward": if self.genome.network_type == "feedforward":
u_conns = unflatten_conns(nodes_, conns_) u_conns = unflatten_conns(nodes_, conns_)
conns_exist = u_conns != I_INF conns_exist = u_conns != I_INF
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
@@ -175,7 +175,7 @@ class DefaultMutation(BaseMutation):
successful, successful,
) )
elif genome.network_type == "recurrent": elif self.genome.network_type == "recurrent":
return jax.lax.cond( return jax.lax.cond(
is_already_exist | (remain_conn_space < 1), is_already_exist | (remain_conn_space < 1),
nothing, nothing,
@@ -183,7 +183,7 @@ class DefaultMutation(BaseMutation):
) )
else: else:
raise ValueError(f"Invalid network type: {genome.network_type}") raise ValueError(f"Invalid network type: {self.genome.network_type}")
def mutate_delete_conn(key_, nodes_, conns_): def mutate_delete_conn(key_, nodes_, conns_):
# randomly choose a connection # randomly choose a connection
@@ -223,19 +223,19 @@ class DefaultMutation(BaseMutation):
return nodes, conns return nodes, conns
def mutate_values(self, state, randkey, genome, nodes, conns): def mutate_values(self, state, randkey, nodes, conns):
k1, k2 = jax.random.split(randkey) k1, k2 = jax.random.split(randkey)
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes) nodes_randkeys = jax.random.split(k1, num=self.genome.max_nodes)
conns_randkeys = jax.random.split(k2, num=genome.max_conns) conns_randkeys = jax.random.split(k2, num=self.genome.max_conns)
node_attrs = vmap(extract_node_attrs)(nodes) node_attrs = vmap(extract_node_attrs)(nodes)
new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( new_node_attrs = vmap(self.genome.node_gene.mutate, in_axes=(None, 0, 0))(
state, nodes_randkeys, node_attrs state, nodes_randkeys, node_attrs
) )
new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs) new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
conn_attrs = vmap(extract_conn_attrs)(conns) conn_attrs = vmap(extract_conn_attrs)(conns)
new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( new_conn_attrs = vmap(self.genome.conn_gene.mutate, in_axes=(None, 0, 0))(
state, conns_randkeys, conn_attrs state, conns_randkeys, conn_attrs
) )
new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs) new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)

View File

@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
import datetime, time import datetime, time
import numpy as np import numpy as np
from algorithm import BaseAlgorithm from tensorneat.algorithm import BaseAlgorithm
from problem import BaseProblem from tensorneat.problem import BaseProblem
from problem.rl_env import RLEnv from tensorneat.problem.rl_env import RLEnv
from problem.func_fit import FuncFit from tensorneat.problem.func_fit import FuncFit
from tensorneat.common import State, StatefulBaseClass from tensorneat.common import State, StatefulBaseClass
@@ -187,7 +187,7 @@ class Pipeline(StatefulBaseClass):
print("Fitness limit reached!") print("Fitness limit reached!")
break break
if self.algorithm.generation(state) >= self.generation_limit: if int(state.generation) >= self.generation_limit:
print("Generation limit reached!") print("Generation limit reached!")
if self.is_save: if self.is_save:
@@ -204,6 +204,8 @@ class Pipeline(StatefulBaseClass):
def analysis(self, state, pop, fitnesses): def analysis(self, state, pop, fitnesses):
generation = int(state.generation)
valid_fitnesses = fitnesses[~np.isinf(fitnesses)] valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
max_f, min_f, mean_f, std_f = ( max_f, min_f, mean_f, std_f = (
@@ -223,8 +225,12 @@ class Pipeline(StatefulBaseClass):
self.best_genome = pop[0][max_idx], pop[1][max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save: if self.is_save:
# save best
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx])) best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f: file_name = os.path.join(
self.genome_dir, f"{generation}.npz"
)
with open(file_name, "wb") as f:
np.savez( np.savez(
f, f,
nodes=best_genome[0], nodes=best_genome[0],
@@ -232,42 +238,18 @@ class Pipeline(StatefulBaseClass):
fitness=self.best_fitness, fitness=self.best_fitness,
) )
# save best if save path is not None # append log
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
member_count = jax.device_get(self.algorithm.member_count(state)) f.write(
species_sizes = [int(i) for i in member_count if i > 0] f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
pop = jax.device_get(pop)
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
max_node_cnt, min_node_cnt, mean_node_cnt = (
max(nodes_cnt),
min(nodes_cnt),
np.mean(nodes_cnt),
)
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
max(conns_cnt),
min(conns_cnt),
np.mean(conns_cnt),
)
print( print(
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n", f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
) )
# append log self.algorithm.show_details(state, fitnesses)
if self.is_save:
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
def show(self, state, best, *args, **kwargs): def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, best) transformed = self.algorithm.transform(state, best)