Add tutorials
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -114,4 +114,6 @@ cython_debug/
|
||||
# Other
|
||||
*.log
|
||||
*.pot
|
||||
*.mo
|
||||
*.mo
|
||||
|
||||
tutorials/.ipynb_checkpoints/*
|
||||
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"cells": [],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
132
tutorials/mutated_network.svg
Normal file
132
tutorials/mutated_network.svg
Normal file
@@ -0,0 +1,132 @@
|
||||
<?xml version="1.0" encoding="utf-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
|
||||
<metadata>
|
||||
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
|
||||
<cc:Work>
|
||||
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
|
||||
<dc:date>2025-01-30T16:52:11.134853</dc:date>
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:creator>
|
||||
<cc:Agent>
|
||||
<dc:title>Matplotlib v3.9.1, https://matplotlib.org/</dc:title>
|
||||
</cc:Agent>
|
||||
</dc:creator>
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<defs>
|
||||
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
|
||||
</defs>
|
||||
<g id="figure_1">
|
||||
<g id="patch_1">
|
||||
<path d="M 0 345.6
|
||||
L 460.8 345.6
|
||||
L 460.8 0
|
||||
L 0 0
|
||||
z
|
||||
" style="fill: #ffffff"/>
|
||||
</g>
|
||||
<g id="axes_1">
|
||||
<g id="patch_2">
|
||||
<path d="M 48.095958 312.568974
|
||||
Q 230.401029 244.204573 411.659252 176.232739
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 410.324982 176.09229
|
||||
L 411.659252 176.232739
|
||||
L 410.746331 177.215885
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_3">
|
||||
<path d="M 48.647998 172.8
|
||||
Q 135.193467 172.8 220.620902 172.8
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 219.420902 172.2
|
||||
L 220.620902 172.8
|
||||
L 219.420902 173.4
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_4">
|
||||
<path d="M 48.095958 33.031026
|
||||
Q 230.401029 101.395427 411.659252 169.367261
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 410.746331 168.384115
|
||||
L 411.659252 169.367261
|
||||
L 410.324982 169.50771
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_5">
|
||||
<path d="M 239.061222 172.8
|
||||
Q 325.60669 172.8 411.034125 172.8
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
<path d="M 409.834125 172.2
|
||||
L 411.034125 172.8
|
||||
L 409.834125 173.4
|
||||
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="PathCollection_1">
|
||||
<path d="M 39.986777 324.270171
|
||||
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
|
||||
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
|
||||
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
|
||||
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
|
||||
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
|
||||
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
|
||||
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
|
||||
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
|
||||
z
|
||||
" clip-path="url(#p94424a64bf)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 39.986777 181.460254
|
||||
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
|
||||
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
|
||||
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
|
||||
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
|
||||
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
|
||||
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
|
||||
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
|
||||
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
|
||||
z
|
||||
" clip-path="url(#p94424a64bf)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 39.986777 38.650337
|
||||
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
|
||||
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
|
||||
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
|
||||
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
|
||||
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
|
||||
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
|
||||
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
|
||||
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
|
||||
z
|
||||
" clip-path="url(#p94424a64bf)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 230.4 181.460254
|
||||
C 232.696726 181.460254 234.899694 180.547755 236.523724 178.923724
|
||||
C 238.147755 177.299694 239.060254 175.096726 239.060254 172.8
|
||||
C 239.060254 170.503274 238.147755 168.300306 236.523724 166.676276
|
||||
C 234.899694 165.052245 232.696726 164.139746 230.4 164.139746
|
||||
C 228.103274 164.139746 225.900306 165.052245 224.276276 166.676276
|
||||
C 222.652245 168.300306 221.739746 170.503274 221.739746 172.8
|
||||
C 221.739746 175.096726 222.652245 177.299694 224.276276 178.923724
|
||||
C 225.900306 180.547755 228.103274 181.460254 230.4 181.460254
|
||||
z
|
||||
" clip-path="url(#p94424a64bf)" style="fill: #ffffff; stroke: #000000"/>
|
||||
<path d="M 420.813223 181.460254
|
||||
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
|
||||
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
|
||||
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
|
||||
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
|
||||
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
|
||||
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
|
||||
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
|
||||
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
|
||||
z
|
||||
" clip-path="url(#p94424a64bf)" style="fill: #0000ff; stroke: #000000"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="p94424a64bf">
|
||||
<rect x="0" y="0" width="460.8" height="345.6"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 5.8 KiB |
112
tutorials/origin_network.svg
Normal file
112
tutorials/origin_network.svg
Normal file
@@ -0,0 +1,112 @@
|
||||
<?xml version="1.0" encoding="utf-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
|
||||
<metadata>
|
||||
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
|
||||
<cc:Work>
|
||||
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
|
||||
<dc:date>2025-01-30T16:52:09.785628</dc:date>
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:creator>
|
||||
<cc:Agent>
|
||||
<dc:title>Matplotlib v3.9.1, https://matplotlib.org/</dc:title>
|
||||
</cc:Agent>
|
||||
</dc:creator>
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<defs>
|
||||
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
|
||||
</defs>
|
||||
<g id="figure_1">
|
||||
<g id="patch_1">
|
||||
<path d="M 0 345.6
|
||||
L 460.8 345.6
|
||||
L 460.8 0
|
||||
L 0 0
|
||||
z
|
||||
" style="fill: #ffffff"/>
|
||||
</g>
|
||||
<g id="axes_1">
|
||||
<g id="patch_2">
|
||||
<path d="M 48.095958 312.568974
|
||||
Q 230.401029 244.204573 411.659252 176.232739
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
|
||||
<path d="M 410.324982 176.09229
|
||||
L 411.659252 176.232739
|
||||
L 410.746331 177.215885
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_3">
|
||||
<path d="M 48.647998 172.8
|
||||
Q 230.399113 172.8 411.032194 172.8
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
|
||||
<path d="M 409.832194 172.2
|
||||
L 411.032194 172.8
|
||||
L 409.832194 173.4
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="patch_4">
|
||||
<path d="M 48.095958 33.031026
|
||||
Q 230.401029 101.395427 411.659252 169.367261
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
|
||||
<path d="M 410.746331 168.384115
|
||||
L 411.659252 169.367261
|
||||
L 410.324982 169.50771
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
|
||||
</g>
|
||||
<g id="PathCollection_1">
|
||||
<path d="M 39.986777 324.270171
|
||||
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
|
||||
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
|
||||
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
|
||||
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
|
||||
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
|
||||
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
|
||||
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
|
||||
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
|
||||
z
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 39.986777 181.460254
|
||||
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
|
||||
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
|
||||
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
|
||||
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
|
||||
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
|
||||
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
|
||||
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
|
||||
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
|
||||
z
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 39.986777 38.650337
|
||||
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
|
||||
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
|
||||
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
|
||||
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
|
||||
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
|
||||
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
|
||||
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
|
||||
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
|
||||
z
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: #ffff00; stroke: #000000"/>
|
||||
<path d="M 420.813223 181.460254
|
||||
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
|
||||
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
|
||||
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
|
||||
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
|
||||
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
|
||||
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
|
||||
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
|
||||
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
|
||||
z
|
||||
" clip-path="url(#pa3b9f839b9)" style="fill: #0000ff; stroke: #000000"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="pa3b9f839b9">
|
||||
<rect x="0" y="0" width="460.8" height="345.6"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.8 KiB |
334
tutorials/tutorial-00-functional programming and state.ipynb
Normal file
334
tutorials/tutorial-00-functional programming and state.ipynb
Normal file
@@ -0,0 +1,334 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3b568ac2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tutorial 0: Functional Programming and State"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "302f305f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Functional Programming\n",
|
||||
"\n",
|
||||
"TensorNEAT uses functional programming (because it is based on the JAX framework, and JAX is designed for it).\n",
|
||||
"\n",
|
||||
"Functional Programming is a programming paradigm that treats computation as the evaluation of mathematical functions and avoids changing state and mutable data. Its main features include:\n",
|
||||
"\n",
|
||||
"1. **Pure Functions**: The same input always produces the same output, with no side effects.\n",
|
||||
"2. **Immutable Data**: Once data is created, it cannot be changed. All operations return new data.\n",
|
||||
"3. **Higher-order Functions**: Functions can be passed as arguments to other functions or returned as values."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ee0c8749",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## State\n",
|
||||
"\n",
|
||||
"In TensorNEAT, we use `State` to manage the input and output of functions. `State` can be seem as a python dictionary with additional functions.\n",
|
||||
"\n",
|
||||
"Here are some usages about `State`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "557a2f24",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"state_a=State ({})\n",
|
||||
"state_b=State ({'a': 1, 'b': 2})\n",
|
||||
"state_b.a=1\n",
|
||||
"state_b.b=2\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# import State\n",
|
||||
"from tensorneat.common import State\n",
|
||||
"\n",
|
||||
"# create a new state\n",
|
||||
"state_a = State() # no arguments\n",
|
||||
"state_b = State(a=1, b=2) # kwargs\n",
|
||||
"\n",
|
||||
"print(f\"{state_a=}\")\n",
|
||||
"print(f\"{state_b=}\")\n",
|
||||
"\n",
|
||||
"# get items from state, use dot notation\n",
|
||||
"print(f\"{state_b.a=}\")\n",
|
||||
"print(f\"{state_b.b=}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "2fd169ea",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"state_a=State ({'a': 1, 'b': 2})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# add new items to the state, use register\n",
|
||||
"state_a = state_a.register(a=1, b=2)\n",
|
||||
"print(f\"{state_a=}\")\n",
|
||||
"\n",
|
||||
"# We CANNOT register the existing item\n",
|
||||
"# state_a = state_a.register(a=1)\n",
|
||||
"# will raise ValueError(f\"Key {key} already exists in state\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "8fe0d395",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"state_a=State ({'a': 3, 'b': 4})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# update the value of an item, use update\n",
|
||||
"state_a = state_a.update(a=3, b=4)\n",
|
||||
"print(f\"{state_a=}\")\n",
|
||||
"\n",
|
||||
"# We CANNOT update the non-existing item\n",
|
||||
"# state_a = state_a.update(c=3)\n",
|
||||
"# will raise ValueError(f\"Key {key} does not exist in state\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "b3ed0eed",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"origin_state=State ({'a': 1, 'b': 2})\n",
|
||||
"new_state=State ({'a': 3, 'b': 2})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# State is immutable! We always create a new state, rather than modifying the existing one.\n",
|
||||
"\n",
|
||||
"origin_state = State(a=1, b=2)\n",
|
||||
"new_state = origin_state.update(a=3)\n",
|
||||
"print(f\"{origin_state=}\") # origin_state is not changed\n",
|
||||
"print(f\"{new_state=}\")\n",
|
||||
"\n",
|
||||
"# We can not modify the state directly\n",
|
||||
"# origin_state.a = 3\n",
|
||||
"# will raise AttributeError: AttributeError(\"State is immutable\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "8c73e60c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"new_state=State ({'a': Array(7, dtype=int32, weak_type=True), 'b': Array(7, dtype=int32, weak_type=True), 'c': Array(7, dtype=int32, weak_type=True)})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# State can be used in JAX functions\n",
|
||||
"import jax\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@jax.jit\n",
|
||||
"def func(state):\n",
|
||||
" c = state.a + state.b # fetch items from state\n",
|
||||
" state = state.update(a=c, b=c) # update items in state\n",
|
||||
" state = state.register(c=c) # add new item to state\n",
|
||||
" return state # return state\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"new_state = func(state_a)\n",
|
||||
"print(f\"{new_state=}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "2bd732ca",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loaded_state=State ({'a': 1, 'b': 2, 'c': 3, 'd': 4})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Save the state (use pickle) as file and load it.\n",
|
||||
"state = State(a=1, b=2, c=3, d=4)\n",
|
||||
"state.save(\"tutorial_0_state.pkl\")\n",
|
||||
"loaded_state = State.load(\"tutorial_0_state.pkl\")\n",
|
||||
"print(f\"{loaded_state=}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "846b5374",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Objects in TensorNEAT\n",
|
||||
"\n",
|
||||
"In the object-oriented programming (OOP) paradigm, both data and functions are stored in objects. \n",
|
||||
"\n",
|
||||
"In the functional programming used by TensorNEAT, data is stored in the form of JAX Tensors, while functions are stored in objects.\n",
|
||||
"\n",
|
||||
"For example, when we create an object `genome`, we are not create a genome instance in the NEAT algorithm. We are actually define some functions!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "dac2bee6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorneat.genome import DefaultGenome\n",
|
||||
"\n",
|
||||
"genome = DefaultGenome(\n",
|
||||
" num_inputs=3,\n",
|
||||
" num_outputs=1,\n",
|
||||
" max_nodes=5,\n",
|
||||
" max_conns=5,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c196cc4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`genome` only stores functions that define the operation of the genome in the NEAT algorithm. \n",
|
||||
"\n",
|
||||
"To create a genome that can participate in calculation, we need to do following things."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "7c597c04",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"nodes=Array([[ 0. , 0.5097862 , 1. , 0. , 0. ],\n",
|
||||
" [ 1. , 0.9807121 , 1. , 0. , 0. ],\n",
|
||||
" [ 2. , -0.8425486 , 1. , 0. , 0. ],\n",
|
||||
" [ 3. , -0.53765106, 1. , 0. , 0. ],\n",
|
||||
" [ nan, nan, nan, nan, nan]], dtype=float32, weak_type=True)\n",
|
||||
"conns=Array([[0. , 3. , 0.785558 ],\n",
|
||||
" [1. , 3. , 2.3734226 ],\n",
|
||||
" [2. , 3. , 0.07902155],\n",
|
||||
" [ nan, nan, nan],\n",
|
||||
" [ nan, nan, nan]], dtype=float32, weak_type=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# setup the genome, let the genome class store some useful information in State\n",
|
||||
"state = genome.setup()\n",
|
||||
"\n",
|
||||
"# create a new genome\n",
|
||||
"randkey = jax.random.key(0)\n",
|
||||
"nodes, conns = genome.initialize(state, randkey)\n",
|
||||
"print(f\"{nodes=}\")\n",
|
||||
"print(f\"{conns=}\")\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "1d8ae137",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"outputs=Array([5.231817], dtype=float32, weak_type=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# calculate\n",
|
||||
"inputs = jax.numpy.array([1, 2, 3])\n",
|
||||
"\n",
|
||||
"transformed = genome.transform(state, nodes, conns)\n",
|
||||
"outputs = genome.forward(state, transformed, inputs)\n",
|
||||
"\n",
|
||||
"print(f\"{outputs=}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7f01ed5d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Conclusion\n",
|
||||
"1. TensorNEAT use functional programming paradiam.\n",
|
||||
"2. TensorNEAT provides `State` to manage data.\n",
|
||||
"3. In TensorNEAT, objects are responsible for controlling functions, rather than storing data."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "jax_env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,21 +1,206 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7c6d7313",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tutorial 1: Genome\n",
|
||||
"The genome is the core component of TensorNEAT. It represents the network’s genotype."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "939c40b4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"Before using the Genome, we need to create a `Genome` instance, which controls the behavior of the genome in use. After creating it, use `setup` to initialize."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "286b5fc2-e629-4461-ad95-b69d3d606a78",
|
||||
"execution_count": 66,
|
||||
"id": "3b1836c3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n",
|
||||
"\n",
|
||||
"genome = DefaultGenome(\n",
|
||||
" num_inputs=3, # 3 inputs\n",
|
||||
" num_outputs=1, # 1 output\n",
|
||||
" max_nodes=5, # the network will have at most 5 nodes\n",
|
||||
" max_conns=10, # the network will have at most 10 connections\n",
|
||||
" node_gene=BiasNode(), # node with 3 attributes: bias, aggregation, activation\n",
|
||||
" conn_gene=DefaultConn(), # connection with 1 attribute: weight\n",
|
||||
" mutation=DefaultMutation(\n",
|
||||
" node_add=0.9,\n",
|
||||
" node_delete=0.0,\n",
|
||||
" conn_add=0.9,\n",
|
||||
" conn_delete=0.0\n",
|
||||
" ) # high mutation rate for testing\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"state = genome.setup()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f64c565c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"After creating the genome, we can use it to perform various network operations, including random generation, forward passes, mutation, distance calculation, and more. These operations are JIT-compilable and also support vectorization with `jax.vmap`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"id": "ef66ba76",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"output=Array([-1.5388116], dtype=float32, weak_type=True)\n",
|
||||
"batch_output=Array([[ 2.5477247 ],\n",
|
||||
" [ 2.2334106 ],\n",
|
||||
" [ 1.8713341 ],\n",
|
||||
" [-3.7539673 ],\n",
|
||||
" [ 1.5344429 ],\n",
|
||||
" [ 2.7640016 ],\n",
|
||||
" [ 0.5649997 ],\n",
|
||||
" [-0.32709932],\n",
|
||||
" [ 3.5273829 ],\n",
|
||||
" [-0.64774114]], dtype=float32, weak_type=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"# Initialize a network\n",
|
||||
"nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n",
|
||||
"\n",
|
||||
"# Network forward\n",
|
||||
"single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))\n",
|
||||
"transformed = genome.transform(state, nodes, conns)\n",
|
||||
"output = genome.forward(state, transformed, single_input)\n",
|
||||
"print(f\"{output=}\")\n",
|
||||
"\n",
|
||||
"# Network batch forward\n",
|
||||
"batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n",
|
||||
"batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n",
|
||||
"print(f\"{batch_output=}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 68,
|
||||
"id": "85433b5e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"pop_outputs.shape=(1000, 1)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Initialize a population of networks (1000 networks)\n",
|
||||
"pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n",
|
||||
" state, jax.random.split(jax.random.PRNGKey(0), 1000)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Population forward\n",
|
||||
"pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n",
|
||||
"pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n",
|
||||
" state, pop_nodes, pop_conns\n",
|
||||
")\n",
|
||||
"pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n",
|
||||
" state, pop_transformed, pop_inputs\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"{pop_outputs.shape=}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"id": "f1f9df60",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# visualize the network\n",
|
||||
"network = genome.network_dict(state, nodes, conns) # Transform the network from JAX arrays to a Python dict\n",
|
||||
"genome.visualize(network, save_path=\"./origin_network.svg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "da5ad8ae",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"id": "16966b69",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Mutate the network\n",
|
||||
"mutated_nodes, mutated_conns = genome.execute_mutation(\n",
|
||||
" state,\n",
|
||||
" jax.random.PRNGKey(2),\n",
|
||||
" nodes,\n",
|
||||
" conns,\n",
|
||||
" new_node_key=jnp.asarray(5), # at most 1 node can be added in each mutation\n",
|
||||
" new_conn_keys=jnp.asarray([6, 7, 8]), # at most 3 connections can be added\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Visualize the mutated network\n",
|
||||
"mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n",
|
||||
"genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c4367424",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f9922d34",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "",
|
||||
"name": ""
|
||||
"display_name": "jax_env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": ""
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
528
tutorials/tutorial-02-Usage.ipynb
Normal file
528
tutorials/tutorial-02-Usage.ipynb
Normal file
@@ -0,0 +1,528 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tutorial 2: Usage\n",
|
||||
"\n",
|
||||
"This tutorial introduces how to use TensorNEAT to solve problems. \n",
|
||||
"\n",
|
||||
"TensorNEAT provides a **pipeline**, allowing users to run the NEAT algorithm efficiently after setting up the required components (problem, algorithm, and pipeline). \n",
|
||||
"Once everything is ready, users can call `pipeline.auto_run()` to execute the NEAT algorithm. \n",
|
||||
"The `auto_run()` method maximizes performance by parallelizing the execution using `jax.vmap` and compiling operations with `jax.jit`, making full use of GPU acceleration.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"## Types of Problems in TensorNEAT \n",
|
||||
"\n",
|
||||
"The problems to be solved using TensorNEAT can be categorized into the following cases:\n",
|
||||
"\n",
|
||||
"1. **Problems already provided by TensorNEAT** (Function Fit, Gymnax, Brax) \n",
|
||||
" - In this case, users can directly create a pipeline and execute it.\n",
|
||||
"\n",
|
||||
"2. **Problems not provided by TensorNEAT but are JIT-compatible** (supporting `jax.jit`) \n",
|
||||
" - Users need to create a **Custom Problem class**, then create a pipeline for execution.\n",
|
||||
"\n",
|
||||
"3. **Problems not provided by TensorNEAT and not JIT-compatible** \n",
|
||||
" - In this case, users **cannot** create a pipeline for direct execution. Instead, the NEAT algorithm must be manually executed. \n",
|
||||
" - The detailed method for manual execution is explained below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2.1 Using Existing Benchmarks\n",
|
||||
"\n",
|
||||
"TensorNEAT currently provides benchmarks for **Function Fit (Symbolic Regression)** and **Reinforcement Learning (RL) tasks** using **Gymnax** and **Brax**. \n",
|
||||
"\n",
|
||||
"If you want to use these predefined problems, refer to the **examples** for implementation details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2.2 Custom Jitable Problem\n",
|
||||
"The following code demonstrates how users can define a custom problem and create a pipeline for automatic execution:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 90,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INPUTS.shape=(10, 2), LABELS.shape=(10, 1)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Prepartion\n",
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"from tensorneat.problem import BaseProblem\n",
|
||||
"\n",
|
||||
"# The problem is to fit pagie_polynomial\n",
|
||||
"def pagie_polynomial(inputs):\n",
|
||||
" x, y = inputs\n",
|
||||
" res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))\n",
|
||||
"\n",
|
||||
" # Important! Returns an array with one item, NOT a scalar\n",
|
||||
" return jnp.array([res])\n",
|
||||
"\n",
|
||||
"# Create dataset (10 samples)\n",
|
||||
"INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (10, 2))\n",
|
||||
"LABELS = jax.vmap(pagie_polynomial)(INPUTS)\n",
|
||||
"\n",
|
||||
"print(f\"{INPUTS.shape=}, {LABELS.shape=}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 91,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define the custom Problem\n",
|
||||
"class CustomProblem(BaseProblem):\n",
|
||||
"\n",
|
||||
" jitable = True # necessary\n",
|
||||
"\n",
|
||||
" def evaluate(self, state, randkey, act_func, params):\n",
|
||||
" # Use ``act_func(state, params, inputs)`` to do network forward\n",
|
||||
"\n",
|
||||
" # do batch forward for all inputs (using jax.vamp)\n",
|
||||
" predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
|
||||
" state, params, INPUTS\n",
|
||||
" ) # should be shape (1000, 1)\n",
|
||||
"\n",
|
||||
" # calculate loss\n",
|
||||
" loss = jnp.mean(jnp.square(predict - LABELS))\n",
|
||||
"\n",
|
||||
" # return negative loss as fitness \n",
|
||||
" # TensorNEAT maximizes fitness, equivalent to minimizes loss\n",
|
||||
" return -loss\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def input_shape(self):\n",
|
||||
" # the input shape that the act_func expects\n",
|
||||
" return (2, )\n",
|
||||
" \n",
|
||||
" @property\n",
|
||||
" def output_shape(self):\n",
|
||||
" # the output shape that the act_func returns\n",
|
||||
" return (1, )\n",
|
||||
" \n",
|
||||
" def show(self, state, randkey, act_func, params, *args, **kwargs):\n",
|
||||
" # shocase the performance of one individual\n",
|
||||
" predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
|
||||
" state, params, INPUTS\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" loss = jnp.mean(jnp.square(predict - LABELS))\n",
|
||||
"\n",
|
||||
" msg = \"\"\n",
|
||||
" for i in range(INPUTS.shape[0]):\n",
|
||||
" msg += f\"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\\n\"\n",
|
||||
" msg += f\"loss: {loss}\\n\"\n",
|
||||
" print(msg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 92,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"initializing\n",
|
||||
"initializing finished\n",
|
||||
"start compile\n",
|
||||
"compile finished, cost time: 10.318278s\n",
|
||||
"Generation: 1, Cost time: 45.01ms\n",
|
||||
" \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n",
|
||||
"\n",
|
||||
"\tnode counts: max: 4, min: 3, mean: 3.10\n",
|
||||
" \tconn counts: max: 3, min: 0, mean: 1.88\n",
|
||||
" \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n",
|
||||
"\n",
|
||||
"Generation: 2, Cost time: 53.77ms\n",
|
||||
" \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n",
|
||||
"\n",
|
||||
"\tnode counts: max: 5, min: 3, mean: 3.17\n",
|
||||
" \tconn counts: max: 4, min: 0, mean: 1.87\n",
|
||||
" \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n",
|
||||
"\n",
|
||||
"Generation: 3, Cost time: 21.86ms\n",
|
||||
" \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n",
|
||||
"\n",
|
||||
"\tnode counts: max: 6, min: 3, mean: 3.49\n",
|
||||
" \tconn counts: max: 6, min: 0, mean: 2.47\n",
|
||||
" \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n",
|
||||
"\n",
|
||||
"Generation: 4, Cost time: 24.30ms\n",
|
||||
" \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n",
|
||||
"\n",
|
||||
"\tnode counts: max: 6, min: 3, mean: 3.76\n",
|
||||
" \tconn counts: max: 6, min: 0, mean: 2.66\n",
|
||||
" \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n",
|
||||
"\n",
|
||||
"Generation: 5, Cost time: 27.68ms\n",
|
||||
" \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n",
|
||||
"\n",
|
||||
"\tnode counts: max: 6, min: 3, mean: 3.94\n",
|
||||
" \tconn counts: max: 6, min: 0, mean: 2.80\n",
|
||||
" \tspecies: 20, [145, 150, 103, 148, 21, 36, 64, 48, 34, 26, 34, 36, 39, 7, 18, 26, 37, 10, 11, 7]\n",
|
||||
"\n",
|
||||
"Generation limit reached!\n",
|
||||
"input: [0.85417664 0.16620052], target: [0.3481666], predict: [0.35990623]\n",
|
||||
"input: [0.27605474 0.48728156], target: [0.0591442], predict: [0.12697154]\n",
|
||||
"input: [0.9920441 0.03015983], target: [0.49201378], predict: [0.38432238]\n",
|
||||
"input: [0.21629429 0.37687123], target: [0.02195805], predict: [0.02771863]\n",
|
||||
"input: [0.63070035 0.96144474], target: [0.5973772], predict: [0.62119806]\n",
|
||||
"input: [0.15203023 0.92090297], target: [0.4188713], predict: [0.26794043]\n",
|
||||
"input: [0.30555236 0.29931295], target: [0.01660334], predict: [0.04903176]\n",
|
||||
"input: [0.6925707 0.8542826], target: [0.5345536], predict: [0.6080159]\n",
|
||||
"input: [0.46517384 0.7869307 ], target: [0.3219154], predict: [0.4150214]\n",
|
||||
"input: [0.99605286 0.28018546], target: [0.5021702], predict: [0.5179908]\n",
|
||||
"loss: 0.0055083888582885265\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"from tensorneat.pipeline import Pipeline\n",
|
||||
"from tensorneat.algorithm.neat import NEAT\n",
|
||||
"from tensorneat.genome import DefaultGenome, BiasNode\n",
|
||||
"from tensorneat.problem.func_fit import CustomFuncFit\n",
|
||||
"from tensorneat.common import ACT, AGG\n",
|
||||
"\n",
|
||||
"# Construct the pipeline and run\n",
|
||||
"pipeline = Pipeline(\n",
|
||||
" algorithm=NEAT(\n",
|
||||
" pop_size=1000,\n",
|
||||
" species_size=20,\n",
|
||||
" survival_threshold=0.01,\n",
|
||||
" genome=DefaultGenome(\n",
|
||||
" num_inputs=2,\n",
|
||||
" num_outputs=1,\n",
|
||||
" init_hidden_layers=(),\n",
|
||||
" node_gene=BiasNode(\n",
|
||||
" activation_options=[ACT.identity, ACT.inv],\n",
|
||||
" aggregation_options=[AGG.sum, AGG.product],\n",
|
||||
" ),\n",
|
||||
" output_transform=ACT.identity,\n",
|
||||
" ),\n",
|
||||
" ),\n",
|
||||
" problem=CustomProblem(),\n",
|
||||
" generation_limit=5,\n",
|
||||
" fitness_target=-1e-4,\n",
|
||||
" seed=42,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# initialize state\n",
|
||||
"state = pipeline.setup()\n",
|
||||
"# run until terminate\n",
|
||||
"state, best = pipeline.auto_run(state)\n",
|
||||
"# show result\n",
|
||||
"pipeline.show(state, best)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2.3 Custom Un-Jitable Problem \n",
|
||||
"This scenario is more complex because we cannot directly construct a pipeline to run the NEAT algorithm. The following code demonstrates how to use TensorNEAT to execute an un-jitable custom problem."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 93,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We use Cartpole in gymnasium as the Un-jitable problem\n",
|
||||
"import gymnasium as gym\n",
|
||||
"env = gym.make(\"CartPole-v1\")\n",
|
||||
"\n",
|
||||
"from tensorneat.common import State\n",
|
||||
"# Define the genome and Pre jit necessary functions in genome\n",
|
||||
"genome=DefaultGenome(\n",
|
||||
" num_inputs=4,\n",
|
||||
" num_outputs=2,\n",
|
||||
" init_hidden_layers=(),\n",
|
||||
" node_gene=BiasNode(),\n",
|
||||
" output_transform=jnp.argmax,\n",
|
||||
")\n",
|
||||
"state = State(randkey=jax.random.key(0))\n",
|
||||
"state = genome.setup(state)\n",
|
||||
"\n",
|
||||
"jit_transform = jax.jit(genome.transform)\n",
|
||||
"jit_forward = jax.jit(genome.forward)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 94,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# define the method to evaluate the individual and the population\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"def evaluate(state, nodes, conns):\n",
|
||||
" # evaluate the individual\n",
|
||||
" transformed = jit_transform(state, nodes, conns)\n",
|
||||
"\n",
|
||||
" observation, info = env.reset()\n",
|
||||
" episode_over, total_reward = False, 0\n",
|
||||
" while not episode_over:\n",
|
||||
" action = jit_forward(state, transformed, observation)\n",
|
||||
" # currently the action is a jax array on gpu\n",
|
||||
" # we need move it to cpu for env step\n",
|
||||
" action = jax.device_get(action)\n",
|
||||
"\n",
|
||||
" observation, reward, terminated, truncated, info = env.step(action)\n",
|
||||
" total_reward += reward\n",
|
||||
" episode_over = terminated or truncated\n",
|
||||
"\n",
|
||||
" return total_reward\n",
|
||||
"\n",
|
||||
"def evaluate_population(state, pop_nodes, pop_conns):\n",
|
||||
" # evaluate the population\n",
|
||||
" pop_size = pop_nodes.shape[0]\n",
|
||||
" fitness = []\n",
|
||||
" for i in tqdm(range(pop_size)):\n",
|
||||
" fitness.append(\n",
|
||||
" evaluate(state, pop_nodes[i], pop_conns[i])\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # return a jax array\n",
|
||||
" return jnp.asarray(fitness)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 95,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# define the algorithm\n",
|
||||
"algorithm = NEAT(\n",
|
||||
" pop_size=100,\n",
|
||||
" species_size=20,\n",
|
||||
" survival_threshold=0.1,\n",
|
||||
" genome=genome,\n",
|
||||
")\n",
|
||||
"state = algorithm.setup(state)\n",
|
||||
"\n",
|
||||
"# jit for acceleration\n",
|
||||
"jit_algorithm_ask = jax.jit(algorithm.ask)\n",
|
||||
"jit_algorithm_tell = jax.jit(algorithm.tell)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 96,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Start running...\n",
|
||||
"Generation 0: evaluating population...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 100/100 [00:04<00:00, 23.15it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Generation 0: best fitness: 384.0\n",
|
||||
"Generation 1: evaluating population...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 100/100 [00:13<00:00, 7.37it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Generation 1: best fitness: 500.0\n",
|
||||
"Fitness limit reached!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# run!\n",
|
||||
"print(\"Start running...\")\n",
|
||||
"for generation in range(10):\n",
|
||||
" pop_nodes, pop_conns = jit_algorithm_ask(state)\n",
|
||||
" print(f\"Generation {generation}: evaluating population...\")\n",
|
||||
" fitness = evaluate_population(state, pop_nodes, pop_conns)\n",
|
||||
"\n",
|
||||
" state = jit_algorithm_tell(state, fitness)\n",
|
||||
" print(f\"Generation {generation}: best fitness: {fitness.max()}\")\n",
|
||||
"\n",
|
||||
" if fitness.max() >= 500:\n",
|
||||
" print(\"Fitness limit reached!\")\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above code runs slowly due to the following reasons:\n",
|
||||
"1. We use a `for` loop to evaluate the fitness of each individual in the population sequentially, lacking parallel acceleration.\n",
|
||||
"2. We do not take advantage of TensorNEAT’s GPU parallel execution capabilities.\n",
|
||||
"3. There are too many switches between Python code and JAX code, causing unnecessary overhead.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The following code demonstrates an optimized `gymnasium` evaluation process:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 97,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# use numpy as numpy-python switch takes shorter time than jax-python switch\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"jit_batch_transform = jax.jit(jax.vmap(genome.transform, in_axes=(None, 0, 0)))\n",
|
||||
"jit_batch_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, 0)))\n",
|
||||
"\n",
|
||||
"POP_SIZE = 100\n",
|
||||
"# Use multiple envs\n",
|
||||
"envs = [gym.make(\"CartPole-v1\") for _ in range(POP_SIZE)]\n",
|
||||
"def accelerated_evaluate_population(state, pop_nodes, pop_conns):\n",
|
||||
" # transformed the population using batch transfrom\n",
|
||||
" pop_transformed = jit_batch_transform(state, pop_nodes, pop_conns)\n",
|
||||
"\n",
|
||||
" pop_observation = [env.reset()[0] for env in envs]\n",
|
||||
" pop_observation = np.asarray(pop_observation)\n",
|
||||
" pop_fitness = np.zeros(POP_SIZE)\n",
|
||||
" episode_over = np.zeros(POP_SIZE, dtype=bool)\n",
|
||||
" \n",
|
||||
" while not np.all(episode_over):\n",
|
||||
" # batch forward\n",
|
||||
" pop_action = jit_batch_forward(state, pop_transformed, pop_observation)\n",
|
||||
" pop_action = jax.device_get(pop_action)\n",
|
||||
"\n",
|
||||
" obs, reward, terminated, truncated = [], [], [], []\n",
|
||||
" # we still need to step the envs one by one\n",
|
||||
" for i in range(POP_SIZE):\n",
|
||||
" obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n",
|
||||
" obs.append(obs_)\n",
|
||||
" reward.append(reward_)\n",
|
||||
" terminated.append(terminated_)\n",
|
||||
" truncated.append(truncated_)\n",
|
||||
"\n",
|
||||
" pop_observation = np.asarray(obs)\n",
|
||||
" pop_reward = np.asarray(reward)\n",
|
||||
" pop_terminated = np.asarray(terminated)\n",
|
||||
" pop_truncated = np.asarray(truncated)\n",
|
||||
"\n",
|
||||
" # update fitness and over\n",
|
||||
" pop_fitness += pop_reward * ~episode_over\n",
|
||||
" episode_over = episode_over | pop_terminated | pop_truncated\n",
|
||||
"\n",
|
||||
" return pop_fitness"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 98,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/wls-laptop-ubuntu/miniconda3/envs/jax_env/lib/python3.10/site-packages/gymnasium/envs/classic_control/cartpole.py:180: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n",
|
||||
" logger.warn(\n",
|
||||
"100%|██████████| 100/100 [00:14<00:00, 6.80it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"slow_time=14.704506635665894, fast_time=1.5114572048187256\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(14.704506635665894, 1.5114572048187256)"
|
||||
]
|
||||
},
|
||||
"execution_count": 98,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Compare the speed between these two methods\n",
|
||||
"# prerun once for jax compile\n",
|
||||
"accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
|
||||
"\n",
|
||||
"import time\n",
|
||||
"time_tic = time.time()\n",
|
||||
"fitness_slow = evaluate_population(state, pop_nodes, pop_conns)\n",
|
||||
"slow_time = time.time() - time_tic\n",
|
||||
"\n",
|
||||
"time_tic = time.time()\n",
|
||||
"fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
|
||||
"fast_time = time.time() - time_tic\n",
|
||||
"\n",
|
||||
"print(f\"{slow_time=}, {fast_time=}\")\n",
|
||||
"slow_time, fast_time"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "jax_env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
BIN
tutorials/tutorial_0_state.pkl
Normal file
BIN
tutorials/tutorial_0_state.pkl
Normal file
Binary file not shown.
Reference in New Issue
Block a user