Add tutorials

This commit is contained in:
wls2002
2025-01-30 16:53:24 +08:00
parent ee1a2a8271
commit f67c69776a
8 changed files with 1299 additions and 12 deletions

View File

@@ -1,6 +0,0 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,132 @@
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
<metadata>
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:date>2025-01-30T16:52:11.134853</dc:date>
<dc:format>image/svg+xml</dc:format>
<dc:creator>
<cc:Agent>
<dc:title>Matplotlib v3.9.1, https://matplotlib.org/</dc:title>
</cc:Agent>
</dc:creator>
</cc:Work>
</rdf:RDF>
</metadata>
<defs>
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
</defs>
<g id="figure_1">
<g id="patch_1">
<path d="M 0 345.6
L 460.8 345.6
L 460.8 0
L 0 0
z
" style="fill: #ffffff"/>
</g>
<g id="axes_1">
<g id="patch_2">
<path d="M 48.095958 312.568974
Q 230.401029 244.204573 411.659252 176.232739
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 410.324982 176.09229
L 411.659252 176.232739
L 410.746331 177.215885
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_3">
<path d="M 48.647998 172.8
Q 135.193467 172.8 220.620902 172.8
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 219.420902 172.2
L 220.620902 172.8
L 219.420902 173.4
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_4">
<path d="M 48.095958 33.031026
Q 230.401029 101.395427 411.659252 169.367261
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 410.746331 168.384115
L 411.659252 169.367261
L 410.324982 169.50771
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="patch_5">
<path d="M 239.061222 172.8
Q 325.60669 172.8 411.034125 172.8
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
<path d="M 409.834125 172.2
L 411.034125 172.8
L 409.834125 173.4
" clip-path="url(#p94424a64bf)" style="fill: none; stroke: #4c4c4c; stroke-linecap: round"/>
</g>
<g id="PathCollection_1">
<path d="M 39.986777 324.270171
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
z
" clip-path="url(#p94424a64bf)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 39.986777 181.460254
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
z
" clip-path="url(#p94424a64bf)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 39.986777 38.650337
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
z
" clip-path="url(#p94424a64bf)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 230.4 181.460254
C 232.696726 181.460254 234.899694 180.547755 236.523724 178.923724
C 238.147755 177.299694 239.060254 175.096726 239.060254 172.8
C 239.060254 170.503274 238.147755 168.300306 236.523724 166.676276
C 234.899694 165.052245 232.696726 164.139746 230.4 164.139746
C 228.103274 164.139746 225.900306 165.052245 224.276276 166.676276
C 222.652245 168.300306 221.739746 170.503274 221.739746 172.8
C 221.739746 175.096726 222.652245 177.299694 224.276276 178.923724
C 225.900306 180.547755 228.103274 181.460254 230.4 181.460254
z
" clip-path="url(#p94424a64bf)" style="fill: #ffffff; stroke: #000000"/>
<path d="M 420.813223 181.460254
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
z
" clip-path="url(#p94424a64bf)" style="fill: #0000ff; stroke: #000000"/>
</g>
</g>
</g>
<defs>
<clipPath id="p94424a64bf">
<rect x="0" y="0" width="460.8" height="345.6"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 5.8 KiB

View File

@@ -0,0 +1,112 @@
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="460.8pt" height="345.6pt" viewBox="0 0 460.8 345.6" xmlns="http://www.w3.org/2000/svg" version="1.1">
<metadata>
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:date>2025-01-30T16:52:09.785628</dc:date>
<dc:format>image/svg+xml</dc:format>
<dc:creator>
<cc:Agent>
<dc:title>Matplotlib v3.9.1, https://matplotlib.org/</dc:title>
</cc:Agent>
</dc:creator>
</cc:Work>
</rdf:RDF>
</metadata>
<defs>
<style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
</defs>
<g id="figure_1">
<g id="patch_1">
<path d="M 0 345.6
L 460.8 345.6
L 460.8 0
L 0 0
z
" style="fill: #ffffff"/>
</g>
<g id="axes_1">
<g id="patch_2">
<path d="M 48.095958 312.568974
Q 230.401029 244.204573 411.659252 176.232739
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
<path d="M 410.324982 176.09229
L 411.659252 176.232739
L 410.746331 177.215885
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
</g>
<g id="patch_3">
<path d="M 48.647998 172.8
Q 230.399113 172.8 411.032194 172.8
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
<path d="M 409.832194 172.2
L 411.032194 172.8
L 409.832194 173.4
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
</g>
<g id="patch_4">
<path d="M 48.095958 33.031026
Q 230.401029 101.395427 411.659252 169.367261
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
<path d="M 410.746331 168.384115
L 411.659252 169.367261
L 410.324982 169.50771
" clip-path="url(#pa3b9f839b9)" style="fill: none; stroke: #440154; stroke-linecap: round"/>
</g>
<g id="PathCollection_1">
<path d="M 39.986777 324.270171
C 42.283503 324.270171 44.486471 323.357672 46.110501 321.733642
C 47.734532 320.109611 48.647031 317.906644 48.647031 315.609917
C 48.647031 313.313191 47.734532 311.110224 46.110501 309.486193
C 44.486471 307.862162 42.283503 306.949663 39.986777 306.949663
C 37.690051 306.949663 35.487083 307.862162 33.863053 309.486193
C 32.239022 311.110224 31.326523 313.313191 31.326523 315.609917
C 31.326523 317.906644 32.239022 320.109611 33.863053 321.733642
C 35.487083 323.357672 37.690051 324.270171 39.986777 324.270171
z
" clip-path="url(#pa3b9f839b9)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 39.986777 181.460254
C 42.283503 181.460254 44.486471 180.547755 46.110501 178.923724
C 47.734532 177.299694 48.647031 175.096726 48.647031 172.8
C 48.647031 170.503274 47.734532 168.300306 46.110501 166.676276
C 44.486471 165.052245 42.283503 164.139746 39.986777 164.139746
C 37.690051 164.139746 35.487083 165.052245 33.863053 166.676276
C 32.239022 168.300306 31.326523 170.503274 31.326523 172.8
C 31.326523 175.096726 32.239022 177.299694 33.863053 178.923724
C 35.487083 180.547755 37.690051 181.460254 39.986777 181.460254
z
" clip-path="url(#pa3b9f839b9)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 39.986777 38.650337
C 42.283503 38.650337 44.486471 37.737838 46.110501 36.113807
C 47.734532 34.489776 48.647031 32.286809 48.647031 29.990083
C 48.647031 27.693356 47.734532 25.490389 46.110501 23.866358
C 44.486471 22.242328 42.283503 21.329829 39.986777 21.329829
C 37.690051 21.329829 35.487083 22.242328 33.863053 23.866358
C 32.239022 25.490389 31.326523 27.693356 31.326523 29.990083
C 31.326523 32.286809 32.239022 34.489776 33.863053 36.113807
C 35.487083 37.737838 37.690051 38.650337 39.986777 38.650337
z
" clip-path="url(#pa3b9f839b9)" style="fill: #ffff00; stroke: #000000"/>
<path d="M 420.813223 181.460254
C 423.109949 181.460254 425.312917 180.547755 426.936947 178.923724
C 428.560978 177.299694 429.473477 175.096726 429.473477 172.8
C 429.473477 170.503274 428.560978 168.300306 426.936947 166.676276
C 425.312917 165.052245 423.109949 164.139746 420.813223 164.139746
C 418.516497 164.139746 416.313529 165.052245 414.689499 166.676276
C 413.065468 168.300306 412.152969 170.503274 412.152969 172.8
C 412.152969 175.096726 413.065468 177.299694 414.689499 178.923724
C 416.313529 180.547755 418.516497 181.460254 420.813223 181.460254
z
" clip-path="url(#pa3b9f839b9)" style="fill: #0000ff; stroke: #000000"/>
</g>
</g>
</g>
<defs>
<clipPath id="pa3b9f839b9">
<rect x="0" y="0" width="460.8" height="345.6"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 4.8 KiB

View File

@@ -0,0 +1,334 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3b568ac2",
"metadata": {},
"source": [
"# Tutorial 0: Functional Programming and State"
]
},
{
"cell_type": "markdown",
"id": "302f305f",
"metadata": {},
"source": [
"## Functional Programming\n",
"\n",
"TensorNEAT uses functional programming (because it is based on the JAX framework, and JAX is designed for it).\n",
"\n",
"Functional Programming is a programming paradigm that treats computation as the evaluation of mathematical functions and avoids changing state and mutable data. Its main features include:\n",
"\n",
"1. **Pure Functions**: The same input always produces the same output, with no side effects.\n",
"2. **Immutable Data**: Once data is created, it cannot be changed. All operations return new data.\n",
"3. **Higher-order Functions**: Functions can be passed as arguments to other functions or returned as values."
]
},
{
"cell_type": "markdown",
"id": "ee0c8749",
"metadata": {},
"source": [
"## State\n",
"\n",
"In TensorNEAT, we use `State` to manage the input and output of functions. `State` can be seem as a python dictionary with additional functions.\n",
"\n",
"Here are some usages about `State`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "557a2f24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state_a=State ({})\n",
"state_b=State ({'a': 1, 'b': 2})\n",
"state_b.a=1\n",
"state_b.b=2\n"
]
}
],
"source": [
"# import State\n",
"from tensorneat.common import State\n",
"\n",
"# create a new state\n",
"state_a = State() # no arguments\n",
"state_b = State(a=1, b=2) # kwargs\n",
"\n",
"print(f\"{state_a=}\")\n",
"print(f\"{state_b=}\")\n",
"\n",
"# get items from state, use dot notation\n",
"print(f\"{state_b.a=}\")\n",
"print(f\"{state_b.b=}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2fd169ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state_a=State ({'a': 1, 'b': 2})\n"
]
}
],
"source": [
"# add new items to the state, use register\n",
"state_a = state_a.register(a=1, b=2)\n",
"print(f\"{state_a=}\")\n",
"\n",
"# We CANNOT register the existing item\n",
"# state_a = state_a.register(a=1)\n",
"# will raise ValueError(f\"Key {key} already exists in state\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8fe0d395",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state_a=State ({'a': 3, 'b': 4})\n"
]
}
],
"source": [
"# update the value of an item, use update\n",
"state_a = state_a.update(a=3, b=4)\n",
"print(f\"{state_a=}\")\n",
"\n",
"# We CANNOT update the non-existing item\n",
"# state_a = state_a.update(c=3)\n",
"# will raise ValueError(f\"Key {key} does not exist in state\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b3ed0eed",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"origin_state=State ({'a': 1, 'b': 2})\n",
"new_state=State ({'a': 3, 'b': 2})\n"
]
}
],
"source": [
"# State is immutable! We always create a new state, rather than modifying the existing one.\n",
"\n",
"origin_state = State(a=1, b=2)\n",
"new_state = origin_state.update(a=3)\n",
"print(f\"{origin_state=}\") # origin_state is not changed\n",
"print(f\"{new_state=}\")\n",
"\n",
"# We can not modify the state directly\n",
"# origin_state.a = 3\n",
"# will raise AttributeError: AttributeError(\"State is immutable\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8c73e60c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"new_state=State ({'a': Array(7, dtype=int32, weak_type=True), 'b': Array(7, dtype=int32, weak_type=True), 'c': Array(7, dtype=int32, weak_type=True)})\n"
]
}
],
"source": [
"# State can be used in JAX functions\n",
"import jax\n",
"\n",
"\n",
"@jax.jit\n",
"def func(state):\n",
" c = state.a + state.b # fetch items from state\n",
" state = state.update(a=c, b=c) # update items in state\n",
" state = state.register(c=c) # add new item to state\n",
" return state # return state\n",
"\n",
"\n",
"new_state = func(state_a)\n",
"print(f\"{new_state=}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2bd732ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded_state=State ({'a': 1, 'b': 2, 'c': 3, 'd': 4})\n"
]
}
],
"source": [
"# Save the state (use pickle) as file and load it.\n",
"state = State(a=1, b=2, c=3, d=4)\n",
"state.save(\"tutorial_0_state.pkl\")\n",
"loaded_state = State.load(\"tutorial_0_state.pkl\")\n",
"print(f\"{loaded_state=}\")"
]
},
{
"cell_type": "markdown",
"id": "846b5374",
"metadata": {},
"source": [
"## Objects in TensorNEAT\n",
"\n",
"In the object-oriented programming (OOP) paradigm, both data and functions are stored in objects. \n",
"\n",
"In the functional programming used by TensorNEAT, data is stored in the form of JAX Tensors, while functions are stored in objects.\n",
"\n",
"For example, when we create an object `genome`, we are not create a genome instance in the NEAT algorithm. We are actually define some functions!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dac2bee6",
"metadata": {},
"outputs": [],
"source": [
"from tensorneat.genome import DefaultGenome\n",
"\n",
"genome = DefaultGenome(\n",
" num_inputs=3,\n",
" num_outputs=1,\n",
" max_nodes=5,\n",
" max_conns=5,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2c196cc4",
"metadata": {},
"source": [
"`genome` only stores functions that define the operation of the genome in the NEAT algorithm. \n",
"\n",
"To create a genome that can participate in calculation, we need to do following things."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7c597c04",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nodes=Array([[ 0. , 0.5097862 , 1. , 0. , 0. ],\n",
" [ 1. , 0.9807121 , 1. , 0. , 0. ],\n",
" [ 2. , -0.8425486 , 1. , 0. , 0. ],\n",
" [ 3. , -0.53765106, 1. , 0. , 0. ],\n",
" [ nan, nan, nan, nan, nan]], dtype=float32, weak_type=True)\n",
"conns=Array([[0. , 3. , 0.785558 ],\n",
" [1. , 3. , 2.3734226 ],\n",
" [2. , 3. , 0.07902155],\n",
" [ nan, nan, nan],\n",
" [ nan, nan, nan]], dtype=float32, weak_type=True)\n"
]
}
],
"source": [
"# setup the genome, let the genome class store some useful information in State\n",
"state = genome.setup()\n",
"\n",
"# create a new genome\n",
"randkey = jax.random.key(0)\n",
"nodes, conns = genome.initialize(state, randkey)\n",
"print(f\"{nodes=}\")\n",
"print(f\"{conns=}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1d8ae137",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"outputs=Array([5.231817], dtype=float32, weak_type=True)\n"
]
}
],
"source": [
"# calculate\n",
"inputs = jax.numpy.array([1, 2, 3])\n",
"\n",
"transformed = genome.transform(state, nodes, conns)\n",
"outputs = genome.forward(state, transformed, inputs)\n",
"\n",
"print(f\"{outputs=}\")"
]
},
{
"cell_type": "markdown",
"id": "7f01ed5d",
"metadata": {},
"source": [
"## Conclusion\n",
"1. TensorNEAT use functional programming paradiam.\n",
"2. TensorNEAT provides `State` to manage data.\n",
"3. In TensorNEAT, objects are responsible for controlling functions, rather than storing data."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,21 +1,206 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7c6d7313",
"metadata": {},
"source": [
"# Tutorial 1: Genome\n",
"The genome is the core component of TensorNEAT. It represents the networks genotype."
]
},
{
"cell_type": "markdown",
"id": "939c40b4",
"metadata": {},
"source": [
"\n",
"Before using the Genome, we need to create a `Genome` instance, which controls the behavior of the genome in use. After creating it, use `setup` to initialize."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "286b5fc2-e629-4461-ad95-b69d3d606a78",
"execution_count": 66,
"id": "3b1836c3",
"metadata": {},
"outputs": [],
"source": [
"from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n",
"\n",
"genome = DefaultGenome(\n",
" num_inputs=3, # 3 inputs\n",
" num_outputs=1, # 1 output\n",
" max_nodes=5, # the network will have at most 5 nodes\n",
" max_conns=10, # the network will have at most 10 connections\n",
" node_gene=BiasNode(), # node with 3 attributes: bias, aggregation, activation\n",
" conn_gene=DefaultConn(), # connection with 1 attribute: weight\n",
" mutation=DefaultMutation(\n",
" node_add=0.9,\n",
" node_delete=0.0,\n",
" conn_add=0.9,\n",
" conn_delete=0.0\n",
" ) # high mutation rate for testing\n",
")\n",
"\n",
"state = genome.setup()"
]
},
{
"cell_type": "markdown",
"id": "f64c565c",
"metadata": {},
"source": [
"After creating the genome, we can use it to perform various network operations, including random generation, forward passes, mutation, distance calculation, and more. These operations are JIT-compilable and also support vectorization with `jax.vmap`."
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "ef66ba76",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"output=Array([-1.5388116], dtype=float32, weak_type=True)\n",
"batch_output=Array([[ 2.5477247 ],\n",
" [ 2.2334106 ],\n",
" [ 1.8713341 ],\n",
" [-3.7539673 ],\n",
" [ 1.5344429 ],\n",
" [ 2.7640016 ],\n",
" [ 0.5649997 ],\n",
" [-0.32709932],\n",
" [ 3.5273829 ],\n",
" [-0.64774114]], dtype=float32, weak_type=True)\n"
]
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"# Initialize a network\n",
"nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n",
"\n",
"# Network forward\n",
"single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))\n",
"transformed = genome.transform(state, nodes, conns)\n",
"output = genome.forward(state, transformed, single_input)\n",
"print(f\"{output=}\")\n",
"\n",
"# Network batch forward\n",
"batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n",
"batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n",
"print(f\"{batch_output=}\")"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "85433b5e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pop_outputs.shape=(1000, 1)\n"
]
}
],
"source": [
"# Initialize a population of networks (1000 networks)\n",
"pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n",
" state, jax.random.split(jax.random.PRNGKey(0), 1000)\n",
")\n",
"\n",
"# Population forward\n",
"pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n",
"pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n",
" state, pop_nodes, pop_conns\n",
")\n",
"pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n",
" state, pop_transformed, pop_inputs\n",
")\n",
"\n",
"print(f\"{pop_outputs.shape=}\")"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "f1f9df60",
"metadata": {},
"outputs": [],
"source": [
"# visualize the network\n",
"network = genome.network_dict(state, nodes, conns) # Transform the network from JAX arrays to a Python dict\n",
"genome.visualize(network, save_path=\"./origin_network.svg\")"
]
},
{
"cell_type": "markdown",
"id": "da5ad8ae",
"metadata": {},
"source": [
"![origin_network](./origin_network.svg)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "16966b69",
"metadata": {},
"outputs": [],
"source": [
"# Mutate the network\n",
"mutated_nodes, mutated_conns = genome.execute_mutation(\n",
" state,\n",
" jax.random.PRNGKey(2),\n",
" nodes,\n",
" conns,\n",
" new_node_key=jnp.asarray(5), # at most 1 node can be added in each mutation\n",
" new_conn_keys=jnp.asarray([6, 7, 8]), # at most 3 connections can be added\n",
")\n",
"\n",
"# Visualize the mutated network\n",
"mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n",
"genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")"
]
},
{
"cell_type": "markdown",
"id": "c4367424",
"metadata": {},
"source": [
"![mutated_networ](./mutated_network.svg)"
]
},
{
"cell_type": "markdown",
"id": "f9922d34",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "",
"name": ""
"display_name": "jax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"name": ""
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,528 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial 2: Usage\n",
"\n",
"This tutorial introduces how to use TensorNEAT to solve problems. \n",
"\n",
"TensorNEAT provides a **pipeline**, allowing users to run the NEAT algorithm efficiently after setting up the required components (problem, algorithm, and pipeline). \n",
"Once everything is ready, users can call `pipeline.auto_run()` to execute the NEAT algorithm. \n",
"The `auto_run()` method maximizes performance by parallelizing the execution using `jax.vmap` and compiling operations with `jax.jit`, making full use of GPU acceleration.\n",
"\n",
"---\n",
"\n",
"## Types of Problems in TensorNEAT \n",
"\n",
"The problems to be solved using TensorNEAT can be categorized into the following cases:\n",
"\n",
"1. **Problems already provided by TensorNEAT** (Function Fit, Gymnax, Brax) \n",
" - In this case, users can directly create a pipeline and execute it.\n",
"\n",
"2. **Problems not provided by TensorNEAT but are JIT-compatible** (supporting `jax.jit`) \n",
" - Users need to create a **Custom Problem class**, then create a pipeline for execution.\n",
"\n",
"3. **Problems not provided by TensorNEAT and not JIT-compatible** \n",
" - In this case, users **cannot** create a pipeline for direct execution. Instead, the NEAT algorithm must be manually executed. \n",
" - The detailed method for manual execution is explained below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.1 Using Existing Benchmarks\n",
"\n",
"TensorNEAT currently provides benchmarks for **Function Fit (Symbolic Regression)** and **Reinforcement Learning (RL) tasks** using **Gymnax** and **Brax**. \n",
"\n",
"If you want to use these predefined problems, refer to the **examples** for implementation details."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.2 Custom Jitable Problem\n",
"The following code demonstrates how users can define a custom problem and create a pipeline for automatic execution:"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INPUTS.shape=(10, 2), LABELS.shape=(10, 1)\n"
]
}
],
"source": [
"# Prepartion\n",
"import jax, jax.numpy as jnp\n",
"from tensorneat.problem import BaseProblem\n",
"\n",
"# The problem is to fit pagie_polynomial\n",
"def pagie_polynomial(inputs):\n",
" x, y = inputs\n",
" res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))\n",
"\n",
" # Important! Returns an array with one item, NOT a scalar\n",
" return jnp.array([res])\n",
"\n",
"# Create dataset (10 samples)\n",
"INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (10, 2))\n",
"LABELS = jax.vmap(pagie_polynomial)(INPUTS)\n",
"\n",
"print(f\"{INPUTS.shape=}, {LABELS.shape=}\")"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"# Define the custom Problem\n",
"class CustomProblem(BaseProblem):\n",
"\n",
" jitable = True # necessary\n",
"\n",
" def evaluate(self, state, randkey, act_func, params):\n",
" # Use ``act_func(state, params, inputs)`` to do network forward\n",
"\n",
" # do batch forward for all inputs (using jax.vamp)\n",
" predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
" state, params, INPUTS\n",
" ) # should be shape (1000, 1)\n",
"\n",
" # calculate loss\n",
" loss = jnp.mean(jnp.square(predict - LABELS))\n",
"\n",
" # return negative loss as fitness \n",
" # TensorNEAT maximizes fitness, equivalent to minimizes loss\n",
" return -loss\n",
"\n",
" @property\n",
" def input_shape(self):\n",
" # the input shape that the act_func expects\n",
" return (2, )\n",
" \n",
" @property\n",
" def output_shape(self):\n",
" # the output shape that the act_func returns\n",
" return (1, )\n",
" \n",
" def show(self, state, randkey, act_func, params, *args, **kwargs):\n",
" # shocase the performance of one individual\n",
" predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
" state, params, INPUTS\n",
" )\n",
"\n",
" loss = jnp.mean(jnp.square(predict - LABELS))\n",
"\n",
" msg = \"\"\n",
" for i in range(INPUTS.shape[0]):\n",
" msg += f\"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\\n\"\n",
" msg += f\"loss: {loss}\\n\"\n",
" print(msg)"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initializing\n",
"initializing finished\n",
"start compile\n",
"compile finished, cost time: 10.318278s\n",
"Generation: 1, Cost time: 45.01ms\n",
" \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n",
"\n",
"\tnode counts: max: 4, min: 3, mean: 3.10\n",
" \tconn counts: max: 3, min: 0, mean: 1.88\n",
" \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n",
"\n",
"Generation: 2, Cost time: 53.77ms\n",
" \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n",
"\n",
"\tnode counts: max: 5, min: 3, mean: 3.17\n",
" \tconn counts: max: 4, min: 0, mean: 1.87\n",
" \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n",
"\n",
"Generation: 3, Cost time: 21.86ms\n",
" \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n",
"\n",
"\tnode counts: max: 6, min: 3, mean: 3.49\n",
" \tconn counts: max: 6, min: 0, mean: 2.47\n",
" \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n",
"\n",
"Generation: 4, Cost time: 24.30ms\n",
" \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n",
"\n",
"\tnode counts: max: 6, min: 3, mean: 3.76\n",
" \tconn counts: max: 6, min: 0, mean: 2.66\n",
" \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n",
"\n",
"Generation: 5, Cost time: 27.68ms\n",
" \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n",
"\n",
"\tnode counts: max: 6, min: 3, mean: 3.94\n",
" \tconn counts: max: 6, min: 0, mean: 2.80\n",
" \tspecies: 20, [145, 150, 103, 148, 21, 36, 64, 48, 34, 26, 34, 36, 39, 7, 18, 26, 37, 10, 11, 7]\n",
"\n",
"Generation limit reached!\n",
"input: [0.85417664 0.16620052], target: [0.3481666], predict: [0.35990623]\n",
"input: [0.27605474 0.48728156], target: [0.0591442], predict: [0.12697154]\n",
"input: [0.9920441 0.03015983], target: [0.49201378], predict: [0.38432238]\n",
"input: [0.21629429 0.37687123], target: [0.02195805], predict: [0.02771863]\n",
"input: [0.63070035 0.96144474], target: [0.5973772], predict: [0.62119806]\n",
"input: [0.15203023 0.92090297], target: [0.4188713], predict: [0.26794043]\n",
"input: [0.30555236 0.29931295], target: [0.01660334], predict: [0.04903176]\n",
"input: [0.6925707 0.8542826], target: [0.5345536], predict: [0.6080159]\n",
"input: [0.46517384 0.7869307 ], target: [0.3219154], predict: [0.4150214]\n",
"input: [0.99605286 0.28018546], target: [0.5021702], predict: [0.5179908]\n",
"loss: 0.0055083888582885265\n",
"\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"\n",
"from tensorneat.pipeline import Pipeline\n",
"from tensorneat.algorithm.neat import NEAT\n",
"from tensorneat.genome import DefaultGenome, BiasNode\n",
"from tensorneat.problem.func_fit import CustomFuncFit\n",
"from tensorneat.common import ACT, AGG\n",
"\n",
"# Construct the pipeline and run\n",
"pipeline = Pipeline(\n",
" algorithm=NEAT(\n",
" pop_size=1000,\n",
" species_size=20,\n",
" survival_threshold=0.01,\n",
" genome=DefaultGenome(\n",
" num_inputs=2,\n",
" num_outputs=1,\n",
" init_hidden_layers=(),\n",
" node_gene=BiasNode(\n",
" activation_options=[ACT.identity, ACT.inv],\n",
" aggregation_options=[AGG.sum, AGG.product],\n",
" ),\n",
" output_transform=ACT.identity,\n",
" ),\n",
" ),\n",
" problem=CustomProblem(),\n",
" generation_limit=5,\n",
" fitness_target=-1e-4,\n",
" seed=42,\n",
")\n",
"\n",
"# initialize state\n",
"state = pipeline.setup()\n",
"# run until terminate\n",
"state, best = pipeline.auto_run(state)\n",
"# show result\n",
"pipeline.show(state, best)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3 Custom Un-Jitable Problem \n",
"This scenario is more complex because we cannot directly construct a pipeline to run the NEAT algorithm. The following code demonstrates how to use TensorNEAT to execute an un-jitable custom problem."
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"# We use Cartpole in gymnasium as the Un-jitable problem\n",
"import gymnasium as gym\n",
"env = gym.make(\"CartPole-v1\")\n",
"\n",
"from tensorneat.common import State\n",
"# Define the genome and Pre jit necessary functions in genome\n",
"genome=DefaultGenome(\n",
" num_inputs=4,\n",
" num_outputs=2,\n",
" init_hidden_layers=(),\n",
" node_gene=BiasNode(),\n",
" output_transform=jnp.argmax,\n",
")\n",
"state = State(randkey=jax.random.key(0))\n",
"state = genome.setup(state)\n",
"\n",
"jit_transform = jax.jit(genome.transform)\n",
"jit_forward = jax.jit(genome.forward)"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"# define the method to evaluate the individual and the population\n",
"from tqdm import tqdm\n",
"\n",
"def evaluate(state, nodes, conns):\n",
" # evaluate the individual\n",
" transformed = jit_transform(state, nodes, conns)\n",
"\n",
" observation, info = env.reset()\n",
" episode_over, total_reward = False, 0\n",
" while not episode_over:\n",
" action = jit_forward(state, transformed, observation)\n",
" # currently the action is a jax array on gpu\n",
" # we need move it to cpu for env step\n",
" action = jax.device_get(action)\n",
"\n",
" observation, reward, terminated, truncated, info = env.step(action)\n",
" total_reward += reward\n",
" episode_over = terminated or truncated\n",
"\n",
" return total_reward\n",
"\n",
"def evaluate_population(state, pop_nodes, pop_conns):\n",
" # evaluate the population\n",
" pop_size = pop_nodes.shape[0]\n",
" fitness = []\n",
" for i in tqdm(range(pop_size)):\n",
" fitness.append(\n",
" evaluate(state, pop_nodes[i], pop_conns[i])\n",
" )\n",
"\n",
" # return a jax array\n",
" return jnp.asarray(fitness)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"# define the algorithm\n",
"algorithm = NEAT(\n",
" pop_size=100,\n",
" species_size=20,\n",
" survival_threshold=0.1,\n",
" genome=genome,\n",
")\n",
"state = algorithm.setup(state)\n",
"\n",
"# jit for acceleration\n",
"jit_algorithm_ask = jax.jit(algorithm.ask)\n",
"jit_algorithm_tell = jax.jit(algorithm.tell)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start running...\n",
"Generation 0: evaluating population...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:04<00:00, 23.15it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 0: best fitness: 384.0\n",
"Generation 1: evaluating population...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:13<00:00, 7.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 1: best fitness: 500.0\n",
"Fitness limit reached!\n"
]
}
],
"source": [
"# run!\n",
"print(\"Start running...\")\n",
"for generation in range(10):\n",
" pop_nodes, pop_conns = jit_algorithm_ask(state)\n",
" print(f\"Generation {generation}: evaluating population...\")\n",
" fitness = evaluate_population(state, pop_nodes, pop_conns)\n",
"\n",
" state = jit_algorithm_tell(state, fitness)\n",
" print(f\"Generation {generation}: best fitness: {fitness.max()}\")\n",
"\n",
" if fitness.max() >= 500:\n",
" print(\"Fitness limit reached!\")\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above code runs slowly due to the following reasons:\n",
"1. We use a `for` loop to evaluate the fitness of each individual in the population sequentially, lacking parallel acceleration.\n",
"2. We do not take advantage of TensorNEATs GPU parallel execution capabilities.\n",
"3. There are too many switches between Python code and JAX code, causing unnecessary overhead.\n",
"\n",
"\n",
"The following code demonstrates an optimized `gymnasium` evaluation process:"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"# use numpy as numpy-python switch takes shorter time than jax-python switch\n",
"import numpy as np\n",
"\n",
"jit_batch_transform = jax.jit(jax.vmap(genome.transform, in_axes=(None, 0, 0)))\n",
"jit_batch_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, 0)))\n",
"\n",
"POP_SIZE = 100\n",
"# Use multiple envs\n",
"envs = [gym.make(\"CartPole-v1\") for _ in range(POP_SIZE)]\n",
"def accelerated_evaluate_population(state, pop_nodes, pop_conns):\n",
" # transformed the population using batch transfrom\n",
" pop_transformed = jit_batch_transform(state, pop_nodes, pop_conns)\n",
"\n",
" pop_observation = [env.reset()[0] for env in envs]\n",
" pop_observation = np.asarray(pop_observation)\n",
" pop_fitness = np.zeros(POP_SIZE)\n",
" episode_over = np.zeros(POP_SIZE, dtype=bool)\n",
" \n",
" while not np.all(episode_over):\n",
" # batch forward\n",
" pop_action = jit_batch_forward(state, pop_transformed, pop_observation)\n",
" pop_action = jax.device_get(pop_action)\n",
"\n",
" obs, reward, terminated, truncated = [], [], [], []\n",
" # we still need to step the envs one by one\n",
" for i in range(POP_SIZE):\n",
" obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n",
" obs.append(obs_)\n",
" reward.append(reward_)\n",
" terminated.append(terminated_)\n",
" truncated.append(truncated_)\n",
"\n",
" pop_observation = np.asarray(obs)\n",
" pop_reward = np.asarray(reward)\n",
" pop_terminated = np.asarray(terminated)\n",
" pop_truncated = np.asarray(truncated)\n",
"\n",
" # update fitness and over\n",
" pop_fitness += pop_reward * ~episode_over\n",
" episode_over = episode_over | pop_terminated | pop_truncated\n",
"\n",
" return pop_fitness"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wls-laptop-ubuntu/miniconda3/envs/jax_env/lib/python3.10/site-packages/gymnasium/envs/classic_control/cartpole.py:180: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n",
" logger.warn(\n",
"100%|██████████| 100/100 [00:14<00:00, 6.80it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"slow_time=14.704506635665894, fast_time=1.5114572048187256\n"
]
},
{
"data": {
"text/plain": [
"(14.704506635665894, 1.5114572048187256)"
]
},
"execution_count": 98,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compare the speed between these two methods\n",
"# prerun once for jax compile\n",
"accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
"\n",
"import time\n",
"time_tic = time.time()\n",
"fitness_slow = evaluate_population(state, pop_nodes, pop_conns)\n",
"slow_time = time.time() - time_tic\n",
"\n",
"time_tic = time.time()\n",
"fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
"fast_time = time.time() - time_tic\n",
"\n",
"print(f\"{slow_time=}, {fast_time=}\")\n",
"slow_time, fast_time"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file not shown.