Add CustomFuncFit into problem; Add related examples

This commit is contained in:
root
2024-07-11 18:32:08 +08:00
parent 3cb5fbf581
commit be6a67d7e2
15 changed files with 241 additions and 437 deletions

View File

@@ -0,0 +1,56 @@
import jax.numpy as jnp
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasNode
from tensorneat.problem.func_fit import CustomFuncFit
from tensorneat.common import Act, Agg
def pagie_polynomial(inputs):
x, y = inputs
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
# important! returns an array, NOT a scalar
return jnp.array([res])
if __name__ == "__main__":
custom_problem = CustomFuncFit(
func=pagie_polynomial,
low_bounds=[-1, -1],
upper_bounds=[1, 1],
method="sample",
num_samples=1000,
)
pipeline = Pipeline(
algorithm=NEAT(
pop_size=10000,
species_size=20,
survival_threshold=0.01,
genome=DefaultGenome(
num_inputs=2,
num_outputs=1,
init_hidden_layers=(),
node_gene=BiasNode(
activation_options=[Act.identity, Act.inv, Act.square],
aggregation_options=[Agg.sum, Agg.product],
),
output_transform=Act.identity,
),
),
problem=custom_problem,
generation_limit=100,
fitness_target=-1e-4,
seed=42,
)
# initialize state
state = pipeline.setup()
# run until terminate
state, best = pipeline.auto_run(state)
# show result
# pipeline.show(state, best)
print(pipeline.algorithm.genome.repr(state, *best))

View File

@@ -1,16 +1,39 @@
import jax, jax.numpy as jnp
arr = jnp.ones((10, 10))
a = jnp.array([
[1, 2, 3],
[4, 5, 6]
])
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasNode
from tensorneat.problem.func_fit import CustomFuncFit
from tensorneat.common import Act, Agg
def attach_with_inf(arr, idx):
target_dim = arr.ndim + idx.ndim - 1
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
return jnp.where(expand_idx == 1, jnp.nan, arr[idx])
def pagie_polynomial(inputs):
x, y = inputs
return x + y
if __name__ == "__main__":
genome=DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=3,
max_conns=2,
init_hidden_layers=(),
node_gene=BiasNode(
activation_options=[Act.identity],
aggregation_options=[Agg.sum],
),
output_transform=Act.identity,
mutation=DefaultMutation(
node_add=0,
node_delete=0,
conn_add=0.0,
conn_delete=0.0,
)
)
randkey = jax.random.PRNGKey(42)
state = genome.setup()
nodes, conns = genome.initialize(state, randkey)
print(genome)
b = attach_with_inf(arr, a)
print(b)

View File

@@ -1,415 +0,0 @@
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
<metadata>
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:date>2024-07-10T19:47:34.359228</dc:date>
<dc:format>image/svg+xml</dc:format>
<dc:creator>
<cc:Agent>
<dc:title>Matplotlib v3.9.0, https://matplotlib.org/</dc:title>
</cc:Agent>
</dc:creator>
</cc:Work>
</rdf:RDF>
</metadata>
<defs>
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
</defs>
<g id="figure_1">
<g id="patch_1">
<path d="M 0 345.6
L 460.8 345.6
L 460.8 0
L 0 0
z
" style="fill: #ffffff"/>
</g>
<g id="axes_1">
<g id="patch_2">
<path d="M 44.79098 308.403612
Q 87.590594 244.204191 129.770035 180.93503
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 125.887134 183.153831
L 129.770035 180.93503
L 129.215335 185.372632
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_3">
<path d="M 46.916335 239.00779
Q 87.591722 208.50125 127.372682 178.665529
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 122.972682 179.465529
L 127.372682 178.665529
L 125.372682 182.665529
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_4">
<path d="M 48.647998 172.8
Q 87.590519 172.8 125.415005 172.8
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 121.415005 170.8
L 125.415005 172.8
L 121.415005 174.8
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_5">
<path d="M 46.916335 106.59221
Q 87.591722 137.09875 127.372682 166.934471
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 125.372682 162.934471
L 127.372682 166.934471
L 122.972682 166.134471
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_6">
<path d="M 44.79098 37.196388
Q 87.590594 101.395809 129.770035 164.66497
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 129.215335 160.227368
L 129.770035 164.66497
L 125.887134 162.446169
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_7">
<path d="M 143.30257 175.840943
Q 182.796502 190.651168 221.243586 205.068824
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 218.200516 201.791672
L 221.243586 205.068824
L 216.796023 205.536989
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_8">
<path d="M 143.30257 169.759057
Q 182.796502 154.948832 221.243586 140.531176
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 216.796023 140.063011
L 221.243586 140.531176
L 218.200516 143.808328
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_9">
<path d="M 238.509181 211.543422
Q 278.003113 226.353647 316.450198 240.771303
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 313.407128 237.494151
L 316.450198 240.771303
L 312.002634 241.239468
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_10">
<path d="M 238.509181 205.461536
Q 278.003113 190.651312 316.450198 176.233655
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 312.002634 175.765491
L 316.450198 176.233655
L 313.407128 179.510807
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_11">
<path d="M 236.155746 202.027265
Q 278.00531 154.946506 319.112092 108.701376
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 314.959818 110.362286
L 319.112092 108.701376
L 317.949455 113.019741
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_12">
<path d="M 236.155746 143.572735
Q 278.00531 190.653494 319.112092 236.898624
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 317.949455 232.580259
L 319.112092 236.898624
L 314.959818 235.237714
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_13">
<path d="M 238.509181 140.138464
Q 278.003113 154.948688 316.450198 169.366345
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 313.407128 166.089193
L 316.450198 169.366345
L 312.002634 169.834509
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_14">
<path d="M 238.509181 134.056578
Q 278.003113 119.246353 316.450198 104.828697
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 312.002634 104.360532
L 316.450198 104.828697
L 313.407128 108.105849
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_15">
<path d="M 334.267833 244.204959
Q 373.210353 244.204959 411.03484 244.204959
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 242.204959
L 411.03484 244.204959
L 407.03484 246.204959
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_16">
<path d="M 332.53617 239.00779
Q 373.211557 208.50125 412.992517 178.665529
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 408.592517 179.465529
L 412.992517 178.665529
L 410.992517 182.665529
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_17">
<path d="M 330.410815 236.998654
Q 373.210429 172.799232 415.38987 109.530071
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 411.506968 111.748872
L 415.38987 109.530071
L 414.83517 113.967673
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_18">
<path d="M 332.53617 177.997169
Q 373.211557 208.503709 412.992517 238.339429
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 410.992517 234.339429
L 412.992517 238.339429
L 408.592517 237.539429
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_19">
<path d="M 334.267833 172.8
Q 373.210353 172.8 411.03484 172.8
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 170.8
L 411.03484 172.8
L 407.03484 174.8
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_20">
<path d="M 332.53617 167.602831
Q 373.211557 137.096291 412.992517 107.260571
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 408.592517 108.060571
L 412.992517 107.260571
L 410.992517 111.260571
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_21">
<path d="M 330.410815 108.601346
Q 373.210429 172.800768 415.38987 236.069929
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 414.83517 231.632327
L 415.38987 236.069929
L 411.506968 233.851128
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_22">
<path d="M 332.53617 106.59221
Q 373.211557 137.09875 412.992517 166.934471
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 410.992517 162.934471
L 412.992517 166.934471
L 408.592517 166.134471
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="patch_23">
<path d="M 334.267833 101.395041
Q 373.210353 101.395041 411.03484 101.395041
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
<path d="M 407.03484 99.395041
L 411.03484 101.395041
L 407.03484 103.395041
z
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
</g>
<g id="PathCollection_1">
<path d="M 39.986777 324.270171
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 252.865213
C 42.283503 252.865213 44.486471 251.952714 46.110501 250.328683
C 47.734532 248.704652 48.647031 246.501685 48.647031 244.204959
C 48.647031 241.908232 47.734532 239.705265 46.110501 238.081234
C 44.486471 236.457204 42.283503 235.544705 39.986777 235.544705
C 37.690051 235.544705 35.487083 236.457204 33.863053 238.081234
C 32.239022 239.705265 31.326523 241.908232 31.326523 244.204959
C 31.326523 246.501685 32.239022 248.704652 33.863053 250.328683
C 35.487083 251.952714 37.690051 252.865213 39.986777 252.865213
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 181.460254
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 110.055295
C 42.283503 110.055295 44.486471 109.142796 46.110501 107.518766
C 47.734532 105.894735 48.647031 103.691768 48.647031 101.395041
C 48.647031 99.098315 47.734532 96.895348 46.110501 95.271317
C 44.486471 93.647286 42.283503 92.734787 39.986777 92.734787
C 37.690051 92.734787 35.487083 93.647286 33.863053 95.271317
C 32.239022 96.895348 31.326523 99.098315 31.326523 101.395041
C 31.326523 103.691768 32.239022 105.894735 33.863053 107.518766
C 35.487083 109.142796 37.690051 110.055295 39.986777 110.055295
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 39.986777 38.650337
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 135.193388 181.460254
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
C 143.853642 170.503274 142.941143 168.300306 141.317113 166.676276
C 139.693082 165.052245 137.490115 164.139746 135.193388 164.139746
C 132.896662 164.139746 130.693695 165.052245 129.069664 166.676276
C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 230.4 217.162733
C 232.696726 217.162733 234.899694 216.250234 236.523724 214.626204
C 238.147755 213.002173 239.060254 210.799206 239.060254 208.502479
C 239.060254 206.205753 238.147755 204.002786 236.523724 202.378755
C 234.899694 200.754724 232.696726 199.842225 230.4 199.842225
C 228.103274 199.842225 225.900306 200.754724 224.276276 202.378755
C 222.652245 204.002786 221.739746 206.205753 221.739746 208.502479
C 221.739746 210.799206 222.652245 213.002173 224.276276 214.626204
C 225.900306 216.250234 228.103274 217.162733 230.4 217.162733
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 230.4 145.757775
C 232.696726 145.757775 234.899694 144.845276 236.523724 143.221245
C 238.147755 141.597214 239.060254 139.394247 239.060254 137.097521
C 239.060254 134.800794 238.147755 132.597827 236.523724 130.973796
C 234.899694 129.349766 232.696726 128.437267 230.4 128.437267
C 228.103274 128.437267 225.900306 129.349766 224.276276 130.973796
C 222.652245 132.597827 221.739746 134.800794 221.739746 137.097521
C 221.739746 139.394247 222.652245 141.597214 224.276276 143.221245
C 225.900306 144.845276 228.103274 145.757775 230.4 145.757775
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 252.865213
C 327.903338 252.865213 330.106305 251.952714 331.730336 250.328683
C 333.354367 248.704652 334.266866 246.501685 334.266866 244.204959
C 334.266866 241.908232 333.354367 239.705265 331.730336 238.081234
C 330.106305 236.457204 327.903338 235.544705 325.606612 235.544705
C 323.309885 235.544705 321.106918 236.457204 319.482887 238.081234
C 317.858857 239.705265 316.946358 241.908232 316.946358 244.204959
C 316.946358 246.501685 317.858857 248.704652 319.482887 250.328683
C 321.106918 251.952714 323.309885 252.865213 325.606612 252.865213
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 181.460254
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
C 334.266866 170.503274 333.354367 168.300306 331.730336 166.676276
C 330.106305 165.052245 327.903338 164.139746 325.606612 164.139746
C 323.309885 164.139746 321.106918 165.052245 319.482887 166.676276
C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 325.606612 110.055295
C 327.903338 110.055295 330.106305 109.142796 331.730336 107.518766
C 333.354367 105.894735 334.266866 103.691768 334.266866 101.395041
C 334.266866 99.098315 333.354367 96.895348 331.730336 95.271317
C 330.106305 93.647286 327.903338 92.734787 325.606612 92.734787
C 323.309885 92.734787 321.106918 93.647286 319.482887 95.271317
C 317.858857 96.895348 316.946358 99.098315 316.946358 101.395041
C 316.946358 103.691768 317.858857 105.894735 319.482887 107.518766
C 321.106918 109.142796 323.309885 110.055295 325.606612 110.055295
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 252.865213
C 423.109949 252.865213 425.312917 251.952714 426.936947 250.328683
C 428.560978 248.704652 429.473477 246.501685 429.473477 244.204959
C 429.473477 241.908232 428.560978 239.705265 426.936947 238.081234
C 425.312917 236.457204 423.109949 235.544705 420.813223 235.544705
C 418.516497 235.544705 416.313529 236.457204 414.689499 238.081234
C 413.065468 239.705265 412.152969 241.908232 412.152969 244.204959
C 412.152969 246.501685 413.065468 248.704652 414.689499 250.328683
C 416.313529 251.952714 418.516497 252.865213 420.813223 252.865213
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 181.460254
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
<path d="M 420.813223 110.055295
C 423.109949 110.055295 425.312917 109.142796 426.936947 107.518766
C 428.560978 105.894735 429.473477 103.691768 429.473477 101.395041
C 429.473477 99.098315 428.560978 96.895348 426.936947 95.271317
C 425.312917 93.647286 423.109949 92.734787 420.813223 92.734787
C 418.516497 92.734787 416.313529 93.647286 414.689499 95.271317
C 413.065468 96.895348 412.152969 99.098315 412.152969 101.395041
C 412.152969 103.691768 413.065468 105.894735 414.689499 107.518766
C 416.313529 109.142796 418.516497 110.055295 420.813223 110.055295
z
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
</g>
</g>
</g>
<defs>
<clipPath id="p572566e0dc">
<rect x="0" y="0" width="460.8" height="345.6"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 18 KiB

View File

@@ -76,7 +76,7 @@ class HyperNEAT(BaseAlgorithm):
h_nodes, h_conns = self.substrate.make_nodes(
query_res
), self.substrate.make_conn(query_res)
), self.substrate.make_conns(query_res)
return self.hyper_genome.transform(state, h_nodes, h_conns)

View File

@@ -6,7 +6,7 @@ class BaseSubstrate(StatefulBaseClass):
def make_nodes(self, query_res):
raise NotImplementedError
def make_conn(self, query_res):
def make_conns(self, query_res):
raise NotImplementedError
@property

View File

@@ -1,5 +1,7 @@
import jax.numpy as jnp
from . import BaseSubstrate
from jax import vmap, numpy as jnp
from .base import BaseSubstrate
from tensorneat.genome.utils import set_conn_attrs
class DefaultSubstrate(BaseSubstrate):
@@ -13,8 +15,9 @@ class DefaultSubstrate(BaseSubstrate):
def make_nodes(self, query_res):
return self.nodes
def make_conn(self, query_res):
return self.conns.at[:, 2:].set(query_res) # change weight
def make_conns(self, query_res):
# change weight of conns
return vmap(set_conn_attrs)(self.conns, query_res)
@property
def query_coors(self):

View File

@@ -31,6 +31,7 @@ name2sympy = {
"maxabs": SympyMaxabs,
"mean": SympyMean,
"clip": SympyClip,
"square": SympySquare,
}

View File

@@ -69,6 +69,10 @@ class Act:
z = jnp.clip(z, -10, 10)
return jnp.exp(z)
@staticmethod
def square(z):
return jnp.pow(z, 2)
@staticmethod
def abs(z):
z = jnp.clip(z, -1, 1)

View File

@@ -184,6 +184,12 @@ class SympyExp(sp.Function):
return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
class SympySquare(sp.Function):
@classmethod
def eval(cls, z):
return sp.Pow(z, 2)
class SympyAbs(sp.Function):
@classmethod
def eval(cls, z):

View File

@@ -17,6 +17,8 @@ class DefaultConn(BaseConn):
weight_mutate_power: float = 0.15,
weight_mutate_rate: float = 0.2,
weight_replace_rate: float = 0.015,
weight_lower_bound: float = -5.0,
weight_upper_bound: float = 5.0,
):
super().__init__()
self.weight_init_mean = weight_init_mean
@@ -24,6 +26,9 @@ class DefaultConn(BaseConn):
self.weight_mutate_power = weight_mutate_power
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
self.weight_lower_bound = weight_lower_bound
self.weight_upper_bound = weight_upper_bound
def new_zero_attrs(self, state):
return jnp.array([0.0]) # weight = 0
@@ -36,6 +41,7 @@ class DefaultConn(BaseConn):
jax.random.normal(randkey, ()) * self.weight_init_std
+ self.weight_init_mean
)
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
return jnp.array([weight])
def mutate(self, state, randkey, attrs):
@@ -49,7 +55,7 @@ class DefaultConn(BaseConn):
self.weight_mutate_rate,
self.weight_replace_rate,
)
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
return jnp.array([weight])
def distance(self, state, attrs1, attrs2):

View File

@@ -47,9 +47,9 @@ class BiasNode(BaseNode):
if isinstance(activation_options, Callable):
activation_options = [activation_options]
if len(aggregation_options) == 1 and aggregation_default is None:
if aggregation_default is None:
aggregation_default = aggregation_options[0]
if len(activation_options) == 1 and activation_default is None:
if activation_default is None:
activation_default = activation_options[0]
self.bias_init_mean = bias_init_mean

View File

@@ -52,9 +52,9 @@ class DefaultNode(BaseNode):
if isinstance(activation_options, Callable):
activation_options = [activation_options]
if len(aggregation_options) == 1 and aggregation_default is None:
if aggregation_default is None:
aggregation_default = aggregation_options[0]
if len(activation_options) == 1 and activation_default is None:
if activation_default is None:
activation_default = activation_options[0]
self.bias_init_mean = bias_init_mean

View File

@@ -1,3 +1,4 @@
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d
from .custom import CustomFuncFit

View File

@@ -0,0 +1,119 @@
from typing import Callable, Union, List, Tuple, Sequence
import jax
from jax import vmap, Array, numpy as jnp
import numpy as np
from .func_fit import FuncFit
class CustomFuncFit(FuncFit):
def __init__(
self,
func: Callable,
low_bounds: Union[List, Tuple, Array],
upper_bounds: Union[List, Tuple, Array],
method: str = "sample",
num_samples: int = 100,
step_size: Array = None,
*args,
**kwargs,
):
if isinstance(low_bounds, list) or isinstance(low_bounds, tuple):
low_bounds = np.array(low_bounds, dtype=np.float32)
if isinstance(upper_bounds, list) or isinstance(upper_bounds, tuple):
upper_bounds = np.array(upper_bounds, dtype=np.float32)
try:
out = func(low_bounds)
except Exception as e:
raise ValueError(f"func(low_bounds) raise an exception: {e}")
assert low_bounds.shape == upper_bounds.shape
assert method in {"sample", "grid"}
self.func = func
self.low_bounds = low_bounds
self.upper_bounds = upper_bounds
self.method = method
self.num_samples = num_samples
self.step_size = step_size
self.generate_dataset()
super().__init__(*args, **kwargs)
def generate_dataset(self):
if self.method == "sample":
assert (
self.num_samples > 0
), f"num_samples must be positive, got {self.num_samples}"
inputs = np.zeros(
(self.num_samples, self.low_bounds.shape[0]), dtype=np.float32
)
for i in range(self.low_bounds.shape[0]):
inputs[:, i] = np.random.uniform(
low=self.low_bounds[i],
high=self.upper_bounds[i],
size=(self.num_samples,),
)
elif self.method == "grid":
assert (
self.step_size is not None
), "step_size must be provided when method is 'grid'"
assert (
self.step_size.shape == self.low_bounds.shape
), "step_size must have the same shape as low_bounds"
assert np.all(self.step_size > 0), "step_size must be positive"
inputs = np.zeros((1, 1))
for i in range(self.low_bounds.shape[0]):
new_col = np.arange(
self.low_bounds[i], self.upper_bounds[i], self.step_size[i]
)
inputs = cartesian_product(inputs, new_col[:, None])
inputs = inputs[:, 1:]
else:
raise ValueError(f"Unknown method: {self.method}")
outputs = vmap(self.func)(inputs)
self.data_inputs = jnp.array(inputs)
self.data_outputs = jnp.array(outputs)
@property
def inputs(self):
return self.data_inputs
@property
def targets(self):
return self.data_outputs
@property
def input_shape(self):
return self.data_inputs.shape
@property
def output_shape(self):
return self.data_outputs.shape
def cartesian_product(arr1, arr2):
assert (
arr1.ndim == arr2.ndim
), "arr1 and arr2 must have the same number of dimensions"
assert arr1.ndim <= 2, "arr1 and arr2 must have at most 2 dimensions"
len1 = arr1.shape[0]
len2 = arr2.shape[0]
repeated_arr1 = np.repeat(arr1, len2, axis=0)
tiled_arr2 = np.tile(arr2, (len1, 1))
new_arr = np.concatenate((repeated_arr1, tiled_arr2), axis=1)
return new_arr