Add CustomFuncFit into problem; Add related examples
This commit is contained in:
56
examples/func_fit/custom_func_fit.py
Normal file
56
examples/func_fit/custom_func_fit.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.algorithm.neat import NEAT
|
||||
from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasNode
|
||||
from tensorneat.problem.func_fit import CustomFuncFit
|
||||
from tensorneat.common import Act, Agg
|
||||
|
||||
|
||||
def pagie_polynomial(inputs):
|
||||
x, y = inputs
|
||||
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
|
||||
|
||||
# important! returns an array, NOT a scalar
|
||||
return jnp.array([res])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
custom_problem = CustomFuncFit(
|
||||
func=pagie_polynomial,
|
||||
low_bounds=[-1, -1],
|
||||
upper_bounds=[1, 1],
|
||||
method="sample",
|
||||
num_samples=1000,
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
pop_size=10000,
|
||||
species_size=20,
|
||||
survival_threshold=0.01,
|
||||
genome=DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
init_hidden_layers=(),
|
||||
node_gene=BiasNode(
|
||||
activation_options=[Act.identity, Act.inv, Act.square],
|
||||
aggregation_options=[Agg.sum, Agg.product],
|
||||
),
|
||||
output_transform=Act.identity,
|
||||
),
|
||||
),
|
||||
problem=custom_problem,
|
||||
generation_limit=100,
|
||||
fitness_target=-1e-4,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
# pipeline.show(state, best)
|
||||
print(pipeline.algorithm.genome.repr(state, *best))
|
||||
@@ -1,16 +1,39 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
arr = jnp.ones((10, 10))
|
||||
a = jnp.array([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
])
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.algorithm.neat import NEAT
|
||||
from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasNode
|
||||
from tensorneat.problem.func_fit import CustomFuncFit
|
||||
from tensorneat.common import Act, Agg
|
||||
|
||||
def attach_with_inf(arr, idx):
|
||||
target_dim = arr.ndim + idx.ndim - 1
|
||||
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
|
||||
|
||||
return jnp.where(expand_idx == 1, jnp.nan, arr[idx])
|
||||
def pagie_polynomial(inputs):
|
||||
x, y = inputs
|
||||
return x + y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
genome=DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
max_nodes=3,
|
||||
max_conns=2,
|
||||
init_hidden_layers=(),
|
||||
node_gene=BiasNode(
|
||||
activation_options=[Act.identity],
|
||||
aggregation_options=[Agg.sum],
|
||||
),
|
||||
output_transform=Act.identity,
|
||||
mutation=DefaultMutation(
|
||||
node_add=0,
|
||||
node_delete=0,
|
||||
conn_add=0.0,
|
||||
conn_delete=0.0,
|
||||
)
|
||||
)
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
state = genome.setup()
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
print(genome)
|
||||
|
||||
|
||||
b = attach_with_inf(arr, a)
|
||||
print(b)
|
||||
415
network.svg
415
network.svg
@@ -1,415 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
|
||||
<metadata>
|
||||
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
|
||||
<cc:Work>
|
||||
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
|
||||
<dc:date>2024-07-10T19:47:34.359228</dc:date>
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:creator>
|
||||
<cc:Agent>
|
||||
<dc:title>Matplotlib v3.9.0, https://matplotlib.org/</dc:title>
|
||||
</cc:Agent>
|
||||
</dc:creator>
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<defs>
|
||||
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
|
||||
</defs>
|
||||
<g id="figure_1">
|
||||
<g id="patch_1">
|
||||
<path d="M 0 345.6
|
||||
L 460.8 345.6
|
||||
L 460.8 0
|
||||
L 0 0
|
||||
z
|
||||
" style="fill: #ffffff"/>
|
||||
</g>
|
||||
<g id="axes_1">
|
||||
<g id="patch_2">
|
||||
<path d="M 44.79098 308.403612
|
||||
Q 87.590594 244.204191 129.770035 180.93503
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 125.887134 183.153831
|
||||
L 129.770035 180.93503
|
||||
L 129.215335 185.372632
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_3">
|
||||
<path d="M 46.916335 239.00779
|
||||
Q 87.591722 208.50125 127.372682 178.665529
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 122.972682 179.465529
|
||||
L 127.372682 178.665529
|
||||
L 125.372682 182.665529
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_4">
|
||||
<path d="M 48.647998 172.8
|
||||
Q 87.590519 172.8 125.415005 172.8
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 121.415005 170.8
|
||||
L 125.415005 172.8
|
||||
L 121.415005 174.8
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_5">
|
||||
<path d="M 46.916335 106.59221
|
||||
Q 87.591722 137.09875 127.372682 166.934471
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 125.372682 162.934471
|
||||
L 127.372682 166.934471
|
||||
L 122.972682 166.134471
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_6">
|
||||
<path d="M 44.79098 37.196388
|
||||
Q 87.590594 101.395809 129.770035 164.66497
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 129.215335 160.227368
|
||||
L 129.770035 164.66497
|
||||
L 125.887134 162.446169
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_7">
|
||||
<path d="M 143.30257 175.840943
|
||||
Q 182.796502 190.651168 221.243586 205.068824
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 218.200516 201.791672
|
||||
L 221.243586 205.068824
|
||||
L 216.796023 205.536989
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_8">
|
||||
<path d="M 143.30257 169.759057
|
||||
Q 182.796502 154.948832 221.243586 140.531176
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 216.796023 140.063011
|
||||
L 221.243586 140.531176
|
||||
L 218.200516 143.808328
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_9">
|
||||
<path d="M 238.509181 211.543422
|
||||
Q 278.003113 226.353647 316.450198 240.771303
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 313.407128 237.494151
|
||||
L 316.450198 240.771303
|
||||
L 312.002634 241.239468
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_10">
|
||||
<path d="M 238.509181 205.461536
|
||||
Q 278.003113 190.651312 316.450198 176.233655
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 312.002634 175.765491
|
||||
L 316.450198 176.233655
|
||||
L 313.407128 179.510807
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_11">
|
||||
<path d="M 236.155746 202.027265
|
||||
Q 278.00531 154.946506 319.112092 108.701376
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 314.959818 110.362286
|
||||
L 319.112092 108.701376
|
||||
L 317.949455 113.019741
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_12">
|
||||
<path d="M 236.155746 143.572735
|
||||
Q 278.00531 190.653494 319.112092 236.898624
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 317.949455 232.580259
|
||||
L 319.112092 236.898624
|
||||
L 314.959818 235.237714
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_13">
|
||||
<path d="M 238.509181 140.138464
|
||||
Q 278.003113 154.948688 316.450198 169.366345
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 313.407128 166.089193
|
||||
L 316.450198 169.366345
|
||||
L 312.002634 169.834509
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_14">
|
||||
<path d="M 238.509181 134.056578
|
||||
Q 278.003113 119.246353 316.450198 104.828697
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 312.002634 104.360532
|
||||
L 316.450198 104.828697
|
||||
L 313.407128 108.105849
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_15">
|
||||
<path d="M 334.267833 244.204959
|
||||
Q 373.210353 244.204959 411.03484 244.204959
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 407.03484 242.204959
|
||||
L 411.03484 244.204959
|
||||
L 407.03484 246.204959
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_16">
|
||||
<path d="M 332.53617 239.00779
|
||||
Q 373.211557 208.50125 412.992517 178.665529
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 408.592517 179.465529
|
||||
L 412.992517 178.665529
|
||||
L 410.992517 182.665529
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_17">
|
||||
<path d="M 330.410815 236.998654
|
||||
Q 373.210429 172.799232 415.38987 109.530071
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 411.506968 111.748872
|
||||
L 415.38987 109.530071
|
||||
L 414.83517 113.967673
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_18">
|
||||
<path d="M 332.53617 177.997169
|
||||
Q 373.211557 208.503709 412.992517 238.339429
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 410.992517 234.339429
|
||||
L 412.992517 238.339429
|
||||
L 408.592517 237.539429
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_19">
|
||||
<path d="M 334.267833 172.8
|
||||
Q 373.210353 172.8 411.03484 172.8
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 407.03484 170.8
|
||||
L 411.03484 172.8
|
||||
L 407.03484 174.8
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_20">
|
||||
<path d="M 332.53617 167.602831
|
||||
Q 373.211557 137.096291 412.992517 107.260571
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 408.592517 108.060571
|
||||
L 412.992517 107.260571
|
||||
L 410.992517 111.260571
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_21">
|
||||
<path d="M 330.410815 108.601346
|
||||
Q 373.210429 172.800768 415.38987 236.069929
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 414.83517 231.632327
|
||||
L 415.38987 236.069929
|
||||
L 411.506968 233.851128
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_22">
|
||||
<path d="M 332.53617 106.59221
|
||||
Q 373.211557 137.09875 412.992517 166.934471
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 410.992517 162.934471
|
||||
L 412.992517 166.934471
|
||||
L 408.592517 166.134471
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_23">
|
||||
<path d="M 334.267833 101.395041
|
||||
Q 373.210353 101.395041 411.03484 101.395041
|
||||
" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/>
|
||||
<path d="M 407.03484 99.395041
|
||||
L 411.03484 101.395041
|
||||
L 407.03484 103.395041
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="PathCollection_1">
|
||||
<path d="M 39.986777 324.270171
|
||||
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
|
||||
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
|
||||
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
|
||||
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
|
||||
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
|
||||
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
|
||||
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
|
||||
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 252.865213
|
||||
C 42.283503 252.865213 44.486471 251.952714 46.110501 250.328683
|
||||
C 47.734532 248.704652 48.647031 246.501685 48.647031 244.204959
|
||||
C 48.647031 241.908232 47.734532 239.705265 46.110501 238.081234
|
||||
C 44.486471 236.457204 42.283503 235.544705 39.986777 235.544705
|
||||
C 37.690051 235.544705 35.487083 236.457204 33.863053 238.081234
|
||||
C 32.239022 239.705265 31.326523 241.908232 31.326523 244.204959
|
||||
C 31.326523 246.501685 32.239022 248.704652 33.863053 250.328683
|
||||
C 35.487083 251.952714 37.690051 252.865213 39.986777 252.865213
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 181.460254
|
||||
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
|
||||
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
|
||||
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
|
||||
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
|
||||
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
|
||||
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
|
||||
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
|
||||
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 110.055295
|
||||
C 42.283503 110.055295 44.486471 109.142796 46.110501 107.518766
|
||||
C 47.734532 105.894735 48.647031 103.691768 48.647031 101.395041
|
||||
C 48.647031 99.098315 47.734532 96.895348 46.110501 95.271317
|
||||
C 44.486471 93.647286 42.283503 92.734787 39.986777 92.734787
|
||||
C 37.690051 92.734787 35.487083 93.647286 33.863053 95.271317
|
||||
C 32.239022 96.895348 31.326523 99.098315 31.326523 101.395041
|
||||
C 31.326523 103.691768 32.239022 105.894735 33.863053 107.518766
|
||||
C 35.487083 109.142796 37.690051 110.055295 39.986777 110.055295
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 39.986777 38.650337
|
||||
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
|
||||
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
|
||||
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
|
||||
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
|
||||
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
|
||||
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
|
||||
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
|
||||
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 135.193388 181.460254
|
||||
C 137.490115 181.460254 139.693082 180.547755 141.317113 178.923724
|
||||
C 142.941143 177.299694 143.853642 175.096726 143.853642 172.8
|
||||
C 143.853642 170.503274 142.941143 168.300306 141.317113 166.676276
|
||||
C 139.693082 165.052245 137.490115 164.139746 135.193388 164.139746
|
||||
C 132.896662 164.139746 130.693695 165.052245 129.069664 166.676276
|
||||
C 127.445633 168.300306 126.533134 170.503274 126.533134 172.8
|
||||
C 126.533134 175.096726 127.445633 177.299694 129.069664 178.923724
|
||||
C 130.693695 180.547755 132.896662 181.460254 135.193388 181.460254
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 230.4 217.162733
|
||||
C 232.696726 217.162733 234.899694 216.250234 236.523724 214.626204
|
||||
C 238.147755 213.002173 239.060254 210.799206 239.060254 208.502479
|
||||
C 239.060254 206.205753 238.147755 204.002786 236.523724 202.378755
|
||||
C 234.899694 200.754724 232.696726 199.842225 230.4 199.842225
|
||||
C 228.103274 199.842225 225.900306 200.754724 224.276276 202.378755
|
||||
C 222.652245 204.002786 221.739746 206.205753 221.739746 208.502479
|
||||
C 221.739746 210.799206 222.652245 213.002173 224.276276 214.626204
|
||||
C 225.900306 216.250234 228.103274 217.162733 230.4 217.162733
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 230.4 145.757775
|
||||
C 232.696726 145.757775 234.899694 144.845276 236.523724 143.221245
|
||||
C 238.147755 141.597214 239.060254 139.394247 239.060254 137.097521
|
||||
C 239.060254 134.800794 238.147755 132.597827 236.523724 130.973796
|
||||
C 234.899694 129.349766 232.696726 128.437267 230.4 128.437267
|
||||
C 228.103274 128.437267 225.900306 129.349766 224.276276 130.973796
|
||||
C 222.652245 132.597827 221.739746 134.800794 221.739746 137.097521
|
||||
C 221.739746 139.394247 222.652245 141.597214 224.276276 143.221245
|
||||
C 225.900306 144.845276 228.103274 145.757775 230.4 145.757775
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 325.606612 252.865213
|
||||
C 327.903338 252.865213 330.106305 251.952714 331.730336 250.328683
|
||||
C 333.354367 248.704652 334.266866 246.501685 334.266866 244.204959
|
||||
C 334.266866 241.908232 333.354367 239.705265 331.730336 238.081234
|
||||
C 330.106305 236.457204 327.903338 235.544705 325.606612 235.544705
|
||||
C 323.309885 235.544705 321.106918 236.457204 319.482887 238.081234
|
||||
C 317.858857 239.705265 316.946358 241.908232 316.946358 244.204959
|
||||
C 316.946358 246.501685 317.858857 248.704652 319.482887 250.328683
|
||||
C 321.106918 251.952714 323.309885 252.865213 325.606612 252.865213
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 325.606612 181.460254
|
||||
C 327.903338 181.460254 330.106305 180.547755 331.730336 178.923724
|
||||
C 333.354367 177.299694 334.266866 175.096726 334.266866 172.8
|
||||
C 334.266866 170.503274 333.354367 168.300306 331.730336 166.676276
|
||||
C 330.106305 165.052245 327.903338 164.139746 325.606612 164.139746
|
||||
C 323.309885 164.139746 321.106918 165.052245 319.482887 166.676276
|
||||
C 317.858857 168.300306 316.946358 170.503274 316.946358 172.8
|
||||
C 316.946358 175.096726 317.858857 177.299694 319.482887 178.923724
|
||||
C 321.106918 180.547755 323.309885 181.460254 325.606612 181.460254
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 325.606612 110.055295
|
||||
C 327.903338 110.055295 330.106305 109.142796 331.730336 107.518766
|
||||
C 333.354367 105.894735 334.266866 103.691768 334.266866 101.395041
|
||||
C 334.266866 99.098315 333.354367 96.895348 331.730336 95.271317
|
||||
C 330.106305 93.647286 327.903338 92.734787 325.606612 92.734787
|
||||
C 323.309885 92.734787 321.106918 93.647286 319.482887 95.271317
|
||||
C 317.858857 96.895348 316.946358 99.098315 316.946358 101.395041
|
||||
C 316.946358 103.691768 317.858857 105.894735 319.482887 107.518766
|
||||
C 321.106918 109.142796 323.309885 110.055295 325.606612 110.055295
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 420.813223 252.865213
|
||||
C 423.109949 252.865213 425.312917 251.952714 426.936947 250.328683
|
||||
C 428.560978 248.704652 429.473477 246.501685 429.473477 244.204959
|
||||
C 429.473477 241.908232 428.560978 239.705265 426.936947 238.081234
|
||||
C 425.312917 236.457204 423.109949 235.544705 420.813223 235.544705
|
||||
C 418.516497 235.544705 416.313529 236.457204 414.689499 238.081234
|
||||
C 413.065468 239.705265 412.152969 241.908232 412.152969 244.204959
|
||||
C 412.152969 246.501685 413.065468 248.704652 414.689499 250.328683
|
||||
C 416.313529 251.952714 418.516497 252.865213 420.813223 252.865213
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 420.813223 181.460254
|
||||
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
|
||||
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
|
||||
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
|
||||
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
|
||||
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
|
||||
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
|
||||
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
|
||||
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
<path d="M 420.813223 110.055295
|
||||
C 423.109949 110.055295 425.312917 109.142796 426.936947 107.518766
|
||||
C 428.560978 105.894735 429.473477 103.691768 429.473477 101.395041
|
||||
C 429.473477 99.098315 428.560978 96.895348 426.936947 95.271317
|
||||
C 425.312917 93.647286 423.109949 92.734787 420.813223 92.734787
|
||||
C 418.516497 92.734787 416.313529 93.647286 414.689499 95.271317
|
||||
C 413.065468 96.895348 412.152969 99.098315 412.152969 101.395041
|
||||
C 412.152969 103.691768 413.065468 105.894735 414.689499 107.518766
|
||||
C 416.313529 109.142796 418.516497 110.055295 420.813223 110.055295
|
||||
z
|
||||
" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="p572566e0dc">
|
||||
<rect x="0" y="0" width="460.8" height="345.6"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 18 KiB |
@@ -76,7 +76,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
|
||||
h_nodes, h_conns = self.substrate.make_nodes(
|
||||
query_res
|
||||
), self.substrate.make_conn(query_res)
|
||||
), self.substrate.make_conns(query_res)
|
||||
|
||||
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ class BaseSubstrate(StatefulBaseClass):
|
||||
def make_nodes(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_conn(self, query_res):
|
||||
def make_conns(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
from . import BaseSubstrate
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseSubstrate
|
||||
from tensorneat.genome.utils import set_conn_attrs
|
||||
|
||||
|
||||
class DefaultSubstrate(BaseSubstrate):
|
||||
@@ -13,8 +15,9 @@ class DefaultSubstrate(BaseSubstrate):
|
||||
def make_nodes(self, query_res):
|
||||
return self.nodes
|
||||
|
||||
def make_conn(self, query_res):
|
||||
return self.conns.at[:, 2:].set(query_res) # change weight
|
||||
def make_conns(self, query_res):
|
||||
# change weight of conns
|
||||
return vmap(set_conn_attrs)(self.conns, query_res)
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
|
||||
@@ -31,6 +31,7 @@ name2sympy = {
|
||||
"maxabs": SympyMaxabs,
|
||||
"mean": SympyMean,
|
||||
"clip": SympyClip,
|
||||
"square": SympySquare,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -69,6 +69,10 @@ class Act:
|
||||
z = jnp.clip(z, -10, 10)
|
||||
return jnp.exp(z)
|
||||
|
||||
@staticmethod
|
||||
def square(z):
|
||||
return jnp.pow(z, 2)
|
||||
|
||||
@staticmethod
|
||||
def abs(z):
|
||||
z = jnp.clip(z, -1, 1)
|
||||
|
||||
@@ -184,6 +184,12 @@ class SympyExp(sp.Function):
|
||||
return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympySquare(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Pow(z, 2)
|
||||
|
||||
|
||||
class SympyAbs(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
|
||||
@@ -17,6 +17,8 @@ class DefaultConn(BaseConn):
|
||||
weight_mutate_power: float = 0.15,
|
||||
weight_mutate_rate: float = 0.2,
|
||||
weight_replace_rate: float = 0.015,
|
||||
weight_lower_bound: float = -5.0,
|
||||
weight_upper_bound: float = 5.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight_init_mean = weight_init_mean
|
||||
@@ -24,6 +26,9 @@ class DefaultConn(BaseConn):
|
||||
self.weight_mutate_power = weight_mutate_power
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
self.weight_lower_bound = weight_lower_bound
|
||||
self.weight_upper_bound = weight_upper_bound
|
||||
|
||||
|
||||
def new_zero_attrs(self, state):
|
||||
return jnp.array([0.0]) # weight = 0
|
||||
@@ -36,6 +41,7 @@ class DefaultConn(BaseConn):
|
||||
jax.random.normal(randkey, ()) * self.weight_init_std
|
||||
+ self.weight_init_mean
|
||||
)
|
||||
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
|
||||
return jnp.array([weight])
|
||||
|
||||
def mutate(self, state, randkey, attrs):
|
||||
@@ -49,7 +55,7 @@ class DefaultConn(BaseConn):
|
||||
self.weight_mutate_rate,
|
||||
self.weight_replace_rate,
|
||||
)
|
||||
|
||||
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
|
||||
return jnp.array([weight])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
|
||||
@@ -47,9 +47,9 @@ class BiasNode(BaseNode):
|
||||
if isinstance(activation_options, Callable):
|
||||
activation_options = [activation_options]
|
||||
|
||||
if len(aggregation_options) == 1 and aggregation_default is None:
|
||||
if aggregation_default is None:
|
||||
aggregation_default = aggregation_options[0]
|
||||
if len(activation_options) == 1 and activation_default is None:
|
||||
if activation_default is None:
|
||||
activation_default = activation_options[0]
|
||||
|
||||
self.bias_init_mean = bias_init_mean
|
||||
|
||||
@@ -52,9 +52,9 @@ class DefaultNode(BaseNode):
|
||||
if isinstance(activation_options, Callable):
|
||||
activation_options = [activation_options]
|
||||
|
||||
if len(aggregation_options) == 1 and aggregation_default is None:
|
||||
if aggregation_default is None:
|
||||
aggregation_default = aggregation_options[0]
|
||||
if len(activation_options) == 1 and activation_default is None:
|
||||
if activation_default is None:
|
||||
activation_default = activation_options[0]
|
||||
|
||||
self.bias_init_mean = bias_init_mean
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .func_fit import FuncFit
|
||||
from .xor import XOR
|
||||
from .xor3d import XOR3d
|
||||
from .custom import CustomFuncFit
|
||||
119
tensorneat/problem/func_fit/custom.py
Normal file
119
tensorneat/problem/func_fit/custom.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import Callable, Union, List, Tuple, Sequence
|
||||
|
||||
import jax
|
||||
from jax import vmap, Array, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class CustomFuncFit(FuncFit):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable,
|
||||
low_bounds: Union[List, Tuple, Array],
|
||||
upper_bounds: Union[List, Tuple, Array],
|
||||
method: str = "sample",
|
||||
num_samples: int = 100,
|
||||
step_size: Array = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if isinstance(low_bounds, list) or isinstance(low_bounds, tuple):
|
||||
low_bounds = np.array(low_bounds, dtype=np.float32)
|
||||
if isinstance(upper_bounds, list) or isinstance(upper_bounds, tuple):
|
||||
upper_bounds = np.array(upper_bounds, dtype=np.float32)
|
||||
|
||||
try:
|
||||
out = func(low_bounds)
|
||||
except Exception as e:
|
||||
raise ValueError(f"func(low_bounds) raise an exception: {e}")
|
||||
assert low_bounds.shape == upper_bounds.shape
|
||||
|
||||
assert method in {"sample", "grid"}
|
||||
|
||||
self.func = func
|
||||
self.low_bounds = low_bounds
|
||||
self.upper_bounds = upper_bounds
|
||||
|
||||
self.method = method
|
||||
self.num_samples = num_samples
|
||||
self.step_size = step_size
|
||||
|
||||
self.generate_dataset()
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def generate_dataset(self):
|
||||
|
||||
if self.method == "sample":
|
||||
assert (
|
||||
self.num_samples > 0
|
||||
), f"num_samples must be positive, got {self.num_samples}"
|
||||
|
||||
inputs = np.zeros(
|
||||
(self.num_samples, self.low_bounds.shape[0]), dtype=np.float32
|
||||
)
|
||||
for i in range(self.low_bounds.shape[0]):
|
||||
inputs[:, i] = np.random.uniform(
|
||||
low=self.low_bounds[i],
|
||||
high=self.upper_bounds[i],
|
||||
size=(self.num_samples,),
|
||||
)
|
||||
elif self.method == "grid":
|
||||
assert (
|
||||
self.step_size is not None
|
||||
), "step_size must be provided when method is 'grid'"
|
||||
assert (
|
||||
self.step_size.shape == self.low_bounds.shape
|
||||
), "step_size must have the same shape as low_bounds"
|
||||
assert np.all(self.step_size > 0), "step_size must be positive"
|
||||
|
||||
inputs = np.zeros((1, 1))
|
||||
for i in range(self.low_bounds.shape[0]):
|
||||
new_col = np.arange(
|
||||
self.low_bounds[i], self.upper_bounds[i], self.step_size[i]
|
||||
)
|
||||
inputs = cartesian_product(inputs, new_col[:, None])
|
||||
inputs = inputs[:, 1:]
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
outputs = vmap(self.func)(inputs)
|
||||
|
||||
self.data_inputs = jnp.array(inputs)
|
||||
self.data_outputs = jnp.array(outputs)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return self.data_inputs
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return self.data_outputs
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return self.data_inputs.shape
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self.data_outputs.shape
|
||||
|
||||
|
||||
def cartesian_product(arr1, arr2):
|
||||
assert (
|
||||
arr1.ndim == arr2.ndim
|
||||
), "arr1 and arr2 must have the same number of dimensions"
|
||||
assert arr1.ndim <= 2, "arr1 and arr2 must have at most 2 dimensions"
|
||||
|
||||
len1 = arr1.shape[0]
|
||||
len2 = arr2.shape[0]
|
||||
|
||||
repeated_arr1 = np.repeat(arr1, len2, axis=0)
|
||||
tiled_arr2 = np.tile(arr2, (len1, 1))
|
||||
|
||||
new_arr = np.concatenate((repeated_arr1, tiled_arr2), axis=1)
|
||||
return new_arr
|
||||
Reference in New Issue
Block a user