diff --git a/.gitignore b/.gitignore
index a7dbb1b..30e1f6a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -114,4 +114,6 @@ cython_debug/
# Other
*.log
*.pot
-*.mo
\ No newline at end of file
+*.mo
+
+tutorials/.ipynb_checkpoints/*
\ No newline at end of file
diff --git a/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb b/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb
deleted file mode 100644
index 363fcab..0000000
--- a/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb
+++ /dev/null
@@ -1,6 +0,0 @@
-{
- "cells": [],
- "metadata": {},
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/tutorials/mutated_network.svg b/tutorials/mutated_network.svg
new file mode 100644
index 0000000..d0a1a5b
--- /dev/null
+++ b/tutorials/mutated_network.svg
@@ -0,0 +1,132 @@
+
+
+
diff --git a/tutorials/origin_network.svg b/tutorials/origin_network.svg
new file mode 100644
index 0000000..bde4a7a
--- /dev/null
+++ b/tutorials/origin_network.svg
@@ -0,0 +1,112 @@
+
+
+
diff --git a/tutorials/tutorial-00-functional programming and state.ipynb b/tutorials/tutorial-00-functional programming and state.ipynb
new file mode 100644
index 0000000..e0c2b2f
--- /dev/null
+++ b/tutorials/tutorial-00-functional programming and state.ipynb
@@ -0,0 +1,334 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "3b568ac2",
+ "metadata": {},
+ "source": [
+ "# Tutorial 0: Functional Programming and State"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "302f305f",
+ "metadata": {},
+ "source": [
+ "## Functional Programming\n",
+ "\n",
+ "TensorNEAT uses functional programming (because it is based on the JAX framework, and JAX is designed for it).\n",
+ "\n",
+ "Functional Programming is a programming paradigm that treats computation as the evaluation of mathematical functions and avoids changing state and mutable data. Its main features include:\n",
+ "\n",
+ "1. **Pure Functions**: The same input always produces the same output, with no side effects.\n",
+ "2. **Immutable Data**: Once data is created, it cannot be changed. All operations return new data.\n",
+ "3. **Higher-order Functions**: Functions can be passed as arguments to other functions or returned as values."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ee0c8749",
+ "metadata": {},
+ "source": [
+ "## State\n",
+ "\n",
+    "In TensorNEAT, we use `State` to manage the input and output of functions. `State` can be seen as a Python dictionary with additional functions.\n",
+ "\n",
+    "Here are some usage examples of `State`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "557a2f24",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "state_a=State ({})\n",
+ "state_b=State ({'a': 1, 'b': 2})\n",
+ "state_b.a=1\n",
+ "state_b.b=2\n"
+ ]
+ }
+ ],
+ "source": [
+ "# import State\n",
+ "from tensorneat.common import State\n",
+ "\n",
+ "# create a new state\n",
+ "state_a = State() # no arguments\n",
+ "state_b = State(a=1, b=2) # kwargs\n",
+ "\n",
+ "print(f\"{state_a=}\")\n",
+ "print(f\"{state_b=}\")\n",
+ "\n",
+ "# get items from state, use dot notation\n",
+ "print(f\"{state_b.a=}\")\n",
+ "print(f\"{state_b.b=}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "2fd169ea",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "state_a=State ({'a': 1, 'b': 2})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# add new items to the state, use register\n",
+ "state_a = state_a.register(a=1, b=2)\n",
+ "print(f\"{state_a=}\")\n",
+ "\n",
+ "# We CANNOT register the existing item\n",
+ "# state_a = state_a.register(a=1)\n",
+ "# will raise ValueError(f\"Key {key} already exists in state\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "8fe0d395",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "state_a=State ({'a': 3, 'b': 4})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# update the value of an item, use update\n",
+ "state_a = state_a.update(a=3, b=4)\n",
+ "print(f\"{state_a=}\")\n",
+ "\n",
+ "# We CANNOT update the non-existing item\n",
+ "# state_a = state_a.update(c=3)\n",
+ "# will raise ValueError(f\"Key {key} does not exist in state\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "b3ed0eed",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "origin_state=State ({'a': 1, 'b': 2})\n",
+ "new_state=State ({'a': 3, 'b': 2})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# State is immutable! We always create a new state, rather than modifying the existing one.\n",
+ "\n",
+ "origin_state = State(a=1, b=2)\n",
+ "new_state = origin_state.update(a=3)\n",
+ "print(f\"{origin_state=}\") # origin_state is not changed\n",
+ "print(f\"{new_state=}\")\n",
+ "\n",
+ "# We can not modify the state directly\n",
+ "# origin_state.a = 3\n",
+ "# will raise AttributeError: AttributeError(\"State is immutable\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "8c73e60c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "new_state=State ({'a': Array(7, dtype=int32, weak_type=True), 'b': Array(7, dtype=int32, weak_type=True), 'c': Array(7, dtype=int32, weak_type=True)})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# State can be used in JAX functions\n",
+ "import jax\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "def func(state):\n",
+ " c = state.a + state.b # fetch items from state\n",
+ " state = state.update(a=c, b=c) # update items in state\n",
+ " state = state.register(c=c) # add new item to state\n",
+ " return state # return state\n",
+ "\n",
+ "\n",
+ "new_state = func(state_a)\n",
+ "print(f\"{new_state=}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "2bd732ca",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "loaded_state=State ({'a': 1, 'b': 2, 'c': 3, 'd': 4})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Save the state (use pickle) as file and load it.\n",
+ "state = State(a=1, b=2, c=3, d=4)\n",
+ "state.save(\"tutorial_0_state.pkl\")\n",
+ "loaded_state = State.load(\"tutorial_0_state.pkl\")\n",
+ "print(f\"{loaded_state=}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "846b5374",
+ "metadata": {},
+ "source": [
+ "## Objects in TensorNEAT\n",
+ "\n",
+ "In the object-oriented programming (OOP) paradigm, both data and functions are stored in objects. \n",
+ "\n",
+ "In the functional programming used by TensorNEAT, data is stored in the form of JAX Tensors, while functions are stored in objects.\n",
+ "\n",
+    "For example, when we create an object `genome`, we are not creating a genome instance in the NEAT algorithm. We are actually defining some functions!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "dac2bee6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tensorneat.genome import DefaultGenome\n",
+ "\n",
+ "genome = DefaultGenome(\n",
+ " num_inputs=3,\n",
+ " num_outputs=1,\n",
+ " max_nodes=5,\n",
+ " max_conns=5,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2c196cc4",
+ "metadata": {},
+ "source": [
+ "`genome` only stores functions that define the operation of the genome in the NEAT algorithm. \n",
+ "\n",
+    "To create a genome that can participate in calculation, we need to do the following things."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "7c597c04",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "nodes=Array([[ 0. , 0.5097862 , 1. , 0. , 0. ],\n",
+ " [ 1. , 0.9807121 , 1. , 0. , 0. ],\n",
+ " [ 2. , -0.8425486 , 1. , 0. , 0. ],\n",
+ " [ 3. , -0.53765106, 1. , 0. , 0. ],\n",
+ " [ nan, nan, nan, nan, nan]], dtype=float32, weak_type=True)\n",
+ "conns=Array([[0. , 3. , 0.785558 ],\n",
+ " [1. , 3. , 2.3734226 ],\n",
+ " [2. , 3. , 0.07902155],\n",
+ " [ nan, nan, nan],\n",
+ " [ nan, nan, nan]], dtype=float32, weak_type=True)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# setup the genome, let the genome class store some useful information in State\n",
+ "state = genome.setup()\n",
+ "\n",
+ "# create a new genome\n",
+ "randkey = jax.random.key(0)\n",
+ "nodes, conns = genome.initialize(state, randkey)\n",
+ "print(f\"{nodes=}\")\n",
+ "print(f\"{conns=}\")\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "1d8ae137",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "outputs=Array([5.231817], dtype=float32, weak_type=True)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# calculate\n",
+ "inputs = jax.numpy.array([1, 2, 3])\n",
+ "\n",
+ "transformed = genome.transform(state, nodes, conns)\n",
+ "outputs = genome.forward(state, transformed, inputs)\n",
+ "\n",
+ "print(f\"{outputs=}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7f01ed5d",
+ "metadata": {},
+ "source": [
+ "## Conclusion\n",
+    "1. TensorNEAT uses the functional programming paradigm.\n",
+ "2. TensorNEAT provides `State` to manage data.\n",
+ "3. In TensorNEAT, objects are responsible for controlling functions, rather than storing data."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "jax_env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorials/tutorial-01-genome.ipynb b/tutorials/tutorial-01-genome.ipynb
index d0dbbca..33a0079 100644
--- a/tutorials/tutorial-01-genome.ipynb
+++ b/tutorials/tutorial-01-genome.ipynb
@@ -1,21 +1,206 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "id": "7c6d7313",
+ "metadata": {},
+ "source": [
+ "# Tutorial 1: Genome\n",
+ "The genome is the core component of TensorNEAT. It represents the network’s genotype."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "939c40b4",
+ "metadata": {},
+ "source": [
+ "\n",
+ "Before using the Genome, we need to create a `Genome` instance, which controls the behavior of the genome in use. After creating it, use `setup` to initialize."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": null,
- "id": "286b5fc2-e629-4461-ad95-b69d3d606a78",
+ "execution_count": 66,
+ "id": "3b1836c3",
"metadata": {},
"outputs": [],
+ "source": [
+ "from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n",
+ "\n",
+ "genome = DefaultGenome(\n",
+ " num_inputs=3, # 3 inputs\n",
+ " num_outputs=1, # 1 output\n",
+ " max_nodes=5, # the network will have at most 5 nodes\n",
+ " max_conns=10, # the network will have at most 10 connections\n",
+ " node_gene=BiasNode(), # node with 3 attributes: bias, aggregation, activation\n",
+ " conn_gene=DefaultConn(), # connection with 1 attribute: weight\n",
+ " mutation=DefaultMutation(\n",
+ " node_add=0.9,\n",
+ " node_delete=0.0,\n",
+ " conn_add=0.9,\n",
+ " conn_delete=0.0\n",
+ " ) # high mutation rate for testing\n",
+ ")\n",
+ "\n",
+ "state = genome.setup()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f64c565c",
+ "metadata": {},
+ "source": [
+ "After creating the genome, we can use it to perform various network operations, including random generation, forward passes, mutation, distance calculation, and more. These operations are JIT-compilable and also support vectorization with `jax.vmap`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
+ "id": "ef66ba76",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "output=Array([-1.5388116], dtype=float32, weak_type=True)\n",
+ "batch_output=Array([[ 2.5477247 ],\n",
+ " [ 2.2334106 ],\n",
+ " [ 1.8713341 ],\n",
+ " [-3.7539673 ],\n",
+ " [ 1.5344429 ],\n",
+ " [ 2.7640016 ],\n",
+ " [ 0.5649997 ],\n",
+ " [-0.32709932],\n",
+ " [ 3.5273829 ],\n",
+ " [-0.64774114]], dtype=float32, weak_type=True)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import jax, jax.numpy as jnp\n",
+ "\n",
+ "# Initialize a network\n",
+ "nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n",
+ "\n",
+ "# Network forward\n",
+ "single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))\n",
+ "transformed = genome.transform(state, nodes, conns)\n",
+ "output = genome.forward(state, transformed, single_input)\n",
+ "print(f\"{output=}\")\n",
+ "\n",
+ "# Network batch forward\n",
+ "batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n",
+ "batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n",
+ "print(f\"{batch_output=}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 68,
+ "id": "85433b5e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pop_outputs.shape=(1000, 1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Initialize a population of networks (1000 networks)\n",
+ "pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n",
+ " state, jax.random.split(jax.random.PRNGKey(0), 1000)\n",
+ ")\n",
+ "\n",
+ "# Population forward\n",
+ "pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n",
+ "pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n",
+ " state, pop_nodes, pop_conns\n",
+ ")\n",
+ "pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n",
+ " state, pop_transformed, pop_inputs\n",
+ ")\n",
+ "\n",
+ "print(f\"{pop_outputs.shape=}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "id": "f1f9df60",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# visualize the network\n",
+ "network = genome.network_dict(state, nodes, conns) # Transform the network from JAX arrays to a Python dict\n",
+ "genome.visualize(network, save_path=\"./origin_network.svg\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "da5ad8ae",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "id": "16966b69",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mutate the network\n",
+ "mutated_nodes, mutated_conns = genome.execute_mutation(\n",
+ " state,\n",
+ " jax.random.PRNGKey(2),\n",
+ " nodes,\n",
+ " conns,\n",
+ " new_node_key=jnp.asarray(5), # at most 1 node can be added in each mutation\n",
+ " new_conn_keys=jnp.asarray([6, 7, 8]), # at most 3 connections can be added\n",
+ ")\n",
+ "\n",
+ "# Visualize the mutated network\n",
+ "mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n",
+ "genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c4367424",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f9922d34",
+ "metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
- "display_name": "",
- "name": ""
+ "display_name": "jax_env",
+ "language": "python",
+ "name": "python3"
},
"language_info": {
- "name": ""
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
}
},
"nbformat": 4,
diff --git a/tutorials/tutorial-02-Usage.ipynb b/tutorials/tutorial-02-Usage.ipynb
new file mode 100644
index 0000000..a35d497
--- /dev/null
+++ b/tutorials/tutorial-02-Usage.ipynb
@@ -0,0 +1,528 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 2: Usage\n",
+ "\n",
+ "This tutorial introduces how to use TensorNEAT to solve problems. \n",
+ "\n",
+ "TensorNEAT provides a **pipeline**, allowing users to run the NEAT algorithm efficiently after setting up the required components (problem, algorithm, and pipeline). \n",
+ "Once everything is ready, users can call `pipeline.auto_run()` to execute the NEAT algorithm. \n",
+ "The `auto_run()` method maximizes performance by parallelizing the execution using `jax.vmap` and compiling operations with `jax.jit`, making full use of GPU acceleration.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "## Types of Problems in TensorNEAT \n",
+ "\n",
+ "The problems to be solved using TensorNEAT can be categorized into the following cases:\n",
+ "\n",
+ "1. **Problems already provided by TensorNEAT** (Function Fit, Gymnax, Brax) \n",
+ " - In this case, users can directly create a pipeline and execute it.\n",
+ "\n",
+ "2. **Problems not provided by TensorNEAT but are JIT-compatible** (supporting `jax.jit`) \n",
+ " - Users need to create a **Custom Problem class**, then create a pipeline for execution.\n",
+ "\n",
+ "3. **Problems not provided by TensorNEAT and not JIT-compatible** \n",
+ " - In this case, users **cannot** create a pipeline for direct execution. Instead, the NEAT algorithm must be manually executed. \n",
+ " - The detailed method for manual execution is explained below."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2.1 Using Existing Benchmarks\n",
+ "\n",
+ "TensorNEAT currently provides benchmarks for **Function Fit (Symbolic Regression)** and **Reinforcement Learning (RL) tasks** using **Gymnax** and **Brax**. \n",
+ "\n",
+ "If you want to use these predefined problems, refer to the **examples** for implementation details."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2.2 Custom Jitable Problem\n",
+ "The following code demonstrates how users can define a custom problem and create a pipeline for automatic execution:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INPUTS.shape=(10, 2), LABELS.shape=(10, 1)\n"
+ ]
+ }
+ ],
+ "source": [
+    "# Preparation\n",
+ "import jax, jax.numpy as jnp\n",
+ "from tensorneat.problem import BaseProblem\n",
+ "\n",
+ "# The problem is to fit pagie_polynomial\n",
+ "def pagie_polynomial(inputs):\n",
+ " x, y = inputs\n",
+ " res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))\n",
+ "\n",
+ " # Important! Returns an array with one item, NOT a scalar\n",
+ " return jnp.array([res])\n",
+ "\n",
+ "# Create dataset (10 samples)\n",
+ "INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (10, 2))\n",
+ "LABELS = jax.vmap(pagie_polynomial)(INPUTS)\n",
+ "\n",
+ "print(f\"{INPUTS.shape=}, {LABELS.shape=}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define the custom Problem\n",
+ "class CustomProblem(BaseProblem):\n",
+ "\n",
+ " jitable = True # necessary\n",
+ "\n",
+ " def evaluate(self, state, randkey, act_func, params):\n",
+ " # Use ``act_func(state, params, inputs)`` to do network forward\n",
+ "\n",
+    "        # do batch forward for all inputs (using jax.vmap)\n",
+ " predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
+ " state, params, INPUTS\n",
+    "        )  # should be shape (10, 1), matching the 10 samples in INPUTS\n",
+ "\n",
+ " # calculate loss\n",
+ " loss = jnp.mean(jnp.square(predict - LABELS))\n",
+ "\n",
+ " # return negative loss as fitness \n",
+    "        # TensorNEAT maximizes fitness; maximizing fitness is equivalent to minimizing loss\n",
+ " return -loss\n",
+ "\n",
+ " @property\n",
+ " def input_shape(self):\n",
+ " # the input shape that the act_func expects\n",
+ " return (2, )\n",
+ " \n",
+ " @property\n",
+ " def output_shape(self):\n",
+ " # the output shape that the act_func returns\n",
+ " return (1, )\n",
+ " \n",
+ " def show(self, state, randkey, act_func, params, *args, **kwargs):\n",
+ " # shocase the performance of one individual\n",
+ " predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
+ " state, params, INPUTS\n",
+ " )\n",
+ "\n",
+ " loss = jnp.mean(jnp.square(predict - LABELS))\n",
+ "\n",
+ " msg = \"\"\n",
+ " for i in range(INPUTS.shape[0]):\n",
+ " msg += f\"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\\n\"\n",
+ " msg += f\"loss: {loss}\\n\"\n",
+ " print(msg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "initializing\n",
+ "initializing finished\n",
+ "start compile\n",
+ "compile finished, cost time: 10.318278s\n",
+ "Generation: 1, Cost time: 45.01ms\n",
+ " \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n",
+ "\n",
+ "\tnode counts: max: 4, min: 3, mean: 3.10\n",
+ " \tconn counts: max: 3, min: 0, mean: 1.88\n",
+ " \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n",
+ "\n",
+ "Generation: 2, Cost time: 53.77ms\n",
+ " \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n",
+ "\n",
+ "\tnode counts: max: 5, min: 3, mean: 3.17\n",
+ " \tconn counts: max: 4, min: 0, mean: 1.87\n",
+ " \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n",
+ "\n",
+ "Generation: 3, Cost time: 21.86ms\n",
+ " \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n",
+ "\n",
+ "\tnode counts: max: 6, min: 3, mean: 3.49\n",
+ " \tconn counts: max: 6, min: 0, mean: 2.47\n",
+ " \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n",
+ "\n",
+ "Generation: 4, Cost time: 24.30ms\n",
+ " \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n",
+ "\n",
+ "\tnode counts: max: 6, min: 3, mean: 3.76\n",
+ " \tconn counts: max: 6, min: 0, mean: 2.66\n",
+ " \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n",
+ "\n",
+ "Generation: 5, Cost time: 27.68ms\n",
+ " \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n",
+ "\n",
+ "\tnode counts: max: 6, min: 3, mean: 3.94\n",
+ " \tconn counts: max: 6, min: 0, mean: 2.80\n",
+ " \tspecies: 20, [145, 150, 103, 148, 21, 36, 64, 48, 34, 26, 34, 36, 39, 7, 18, 26, 37, 10, 11, 7]\n",
+ "\n",
+ "Generation limit reached!\n",
+ "input: [0.85417664 0.16620052], target: [0.3481666], predict: [0.35990623]\n",
+ "input: [0.27605474 0.48728156], target: [0.0591442], predict: [0.12697154]\n",
+ "input: [0.9920441 0.03015983], target: [0.49201378], predict: [0.38432238]\n",
+ "input: [0.21629429 0.37687123], target: [0.02195805], predict: [0.02771863]\n",
+ "input: [0.63070035 0.96144474], target: [0.5973772], predict: [0.62119806]\n",
+ "input: [0.15203023 0.92090297], target: [0.4188713], predict: [0.26794043]\n",
+ "input: [0.30555236 0.29931295], target: [0.01660334], predict: [0.04903176]\n",
+ "input: [0.6925707 0.8542826], target: [0.5345536], predict: [0.6080159]\n",
+ "input: [0.46517384 0.7869307 ], target: [0.3219154], predict: [0.4150214]\n",
+ "input: [0.99605286 0.28018546], target: [0.5021702], predict: [0.5179908]\n",
+ "loss: 0.0055083888582885265\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import jax.numpy as jnp\n",
+ "\n",
+ "from tensorneat.pipeline import Pipeline\n",
+ "from tensorneat.algorithm.neat import NEAT\n",
+ "from tensorneat.genome import DefaultGenome, BiasNode\n",
+ "from tensorneat.problem.func_fit import CustomFuncFit\n",
+ "from tensorneat.common import ACT, AGG\n",
+ "\n",
+ "# Construct the pipeline and run\n",
+ "pipeline = Pipeline(\n",
+ " algorithm=NEAT(\n",
+ " pop_size=1000,\n",
+ " species_size=20,\n",
+ " survival_threshold=0.01,\n",
+ " genome=DefaultGenome(\n",
+ " num_inputs=2,\n",
+ " num_outputs=1,\n",
+ " init_hidden_layers=(),\n",
+ " node_gene=BiasNode(\n",
+ " activation_options=[ACT.identity, ACT.inv],\n",
+ " aggregation_options=[AGG.sum, AGG.product],\n",
+ " ),\n",
+ " output_transform=ACT.identity,\n",
+ " ),\n",
+ " ),\n",
+ " problem=CustomProblem(),\n",
+ " generation_limit=5,\n",
+ " fitness_target=-1e-4,\n",
+ " seed=42,\n",
+ ")\n",
+ "\n",
+ "# initialize state\n",
+ "state = pipeline.setup()\n",
+ "# run until terminate\n",
+ "state, best = pipeline.auto_run(state)\n",
+ "# show result\n",
+ "pipeline.show(state, best)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## 2.3 Custom Un-Jitable Problem\n",
+ "This scenario is more complex because we cannot directly construct a pipeline to run the NEAT algorithm. The following code demonstrates how to use TensorNEAT to execute an un-jitable custom problem."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# We use Cartpole in gymnasium as the Un-jitable problem\n",
+ "import gymnasium as gym\n",
+ "env = gym.make(\"CartPole-v1\")\n",
+ "\n",
+ "from tensorneat.common import State\n",
+ "# Define the genome and Pre jit necessary functions in genome\n",
+ "genome=DefaultGenome(\n",
+ " num_inputs=4,\n",
+ " num_outputs=2,\n",
+ " init_hidden_layers=(),\n",
+ " node_gene=BiasNode(),\n",
+ " output_transform=jnp.argmax,\n",
+ ")\n",
+ "state = State(randkey=jax.random.key(0))\n",
+ "state = genome.setup(state)\n",
+ "\n",
+ "jit_transform = jax.jit(genome.transform)\n",
+ "jit_forward = jax.jit(genome.forward)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 94,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the method to evaluate the individual and the population\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "def evaluate(state, nodes, conns):\n",
+ " # evaluate the individual\n",
+ " transformed = jit_transform(state, nodes, conns)\n",
+ "\n",
+ " observation, info = env.reset()\n",
+ " episode_over, total_reward = False, 0\n",
+ " while not episode_over:\n",
+ " action = jit_forward(state, transformed, observation)\n",
+ " # currently the action is a jax array on gpu\n",
+ " # we need move it to cpu for env step\n",
+ " action = jax.device_get(action)\n",
+ "\n",
+ " observation, reward, terminated, truncated, info = env.step(action)\n",
+ " total_reward += reward\n",
+ " episode_over = terminated or truncated\n",
+ "\n",
+ " return total_reward\n",
+ "\n",
+ "def evaluate_population(state, pop_nodes, pop_conns):\n",
+ " # evaluate the population\n",
+ " pop_size = pop_nodes.shape[0]\n",
+ " fitness = []\n",
+ " for i in tqdm(range(pop_size)):\n",
+ " fitness.append(\n",
+ " evaluate(state, pop_nodes[i], pop_conns[i])\n",
+ " )\n",
+ "\n",
+ " # return a jax array\n",
+ " return jnp.asarray(fitness)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 95,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the algorithm\n",
+ "algorithm = NEAT(\n",
+ " pop_size=100,\n",
+ " species_size=20,\n",
+ " survival_threshold=0.1,\n",
+ " genome=genome,\n",
+ ")\n",
+ "state = algorithm.setup(state)\n",
+ "\n",
+ "# jit for acceleration\n",
+ "jit_algorithm_ask = jax.jit(algorithm.ask)\n",
+ "jit_algorithm_tell = jax.jit(algorithm.tell)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Start running...\n",
+ "Generation 0: evaluating population...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:04<00:00, 23.15it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generation 0: best fitness: 384.0\n",
+ "Generation 1: evaluating population...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:13<00:00, 7.37it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generation 1: best fitness: 500.0\n",
+ "Fitness limit reached!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# run!\n",
+ "print(\"Start running...\")\n",
+ "for generation in range(10):\n",
+ " pop_nodes, pop_conns = jit_algorithm_ask(state)\n",
+ " print(f\"Generation {generation}: evaluating population...\")\n",
+ " fitness = evaluate_population(state, pop_nodes, pop_conns)\n",
+ "\n",
+ " state = jit_algorithm_tell(state, fitness)\n",
+ " print(f\"Generation {generation}: best fitness: {fitness.max()}\")\n",
+ "\n",
+ " if fitness.max() >= 500:\n",
+ " print(\"Fitness limit reached!\")\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The above code runs slowly due to the following reasons:\n",
+ "1. We use a `for` loop to evaluate the fitness of each individual in the population sequentially, lacking parallel acceleration.\n",
+ "2. We do not take advantage of TensorNEAT’s GPU parallel execution capabilities.\n",
+ "3. There are too many switches between Python code and JAX code, causing unnecessary overhead.\n",
+ "\n",
+ "\n",
+ "The following code demonstrates an optimized `gymnasium` evaluation process:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# use numpy, as the numpy-python switch takes less time than the jax-python switch\n",
+ "import numpy as np\n",
+ "\n",
+ "jit_batch_transform = jax.jit(jax.vmap(genome.transform, in_axes=(None, 0, 0)))\n",
+ "jit_batch_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, 0)))\n",
+ "\n",
+ "POP_SIZE = 100\n",
+ "# Use multiple envs\n",
+ "envs = [gym.make(\"CartPole-v1\") for _ in range(POP_SIZE)]\n",
+ "def accelerated_evaluate_population(state, pop_nodes, pop_conns):\n",
+    "    # transform the population using batch transform\n",
+ " pop_transformed = jit_batch_transform(state, pop_nodes, pop_conns)\n",
+ "\n",
+ " pop_observation = [env.reset()[0] for env in envs]\n",
+ " pop_observation = np.asarray(pop_observation)\n",
+ " pop_fitness = np.zeros(POP_SIZE)\n",
+ " episode_over = np.zeros(POP_SIZE, dtype=bool)\n",
+ " \n",
+ " while not np.all(episode_over):\n",
+ " # batch forward\n",
+ " pop_action = jit_batch_forward(state, pop_transformed, pop_observation)\n",
+ " pop_action = jax.device_get(pop_action)\n",
+ "\n",
+ " obs, reward, terminated, truncated = [], [], [], []\n",
+ " # we still need to step the envs one by one\n",
+ " for i in range(POP_SIZE):\n",
+ " obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n",
+ " obs.append(obs_)\n",
+ " reward.append(reward_)\n",
+ " terminated.append(terminated_)\n",
+ " truncated.append(truncated_)\n",
+ "\n",
+ " pop_observation = np.asarray(obs)\n",
+ " pop_reward = np.asarray(reward)\n",
+ " pop_terminated = np.asarray(terminated)\n",
+ " pop_truncated = np.asarray(truncated)\n",
+ "\n",
+ " # update fitness and over\n",
+ " pop_fitness += pop_reward * ~episode_over\n",
+ " episode_over = episode_over | pop_terminated | pop_truncated\n",
+ "\n",
+ " return pop_fitness"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/wls-laptop-ubuntu/miniconda3/envs/jax_env/lib/python3.10/site-packages/gymnasium/envs/classic_control/cartpole.py:180: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n",
+ " logger.warn(\n",
+ "100%|██████████| 100/100 [00:14<00:00, 6.80it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "slow_time=14.704506635665894, fast_time=1.5114572048187256\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(14.704506635665894, 1.5114572048187256)"
+ ]
+ },
+ "execution_count": 98,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Compare the speed between these two methods\n",
+ "# prerun once for jax compile\n",
+ "accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
+ "\n",
+ "import time\n",
+ "time_tic = time.time()\n",
+ "fitness_slow = evaluate_population(state, pop_nodes, pop_conns)\n",
+ "slow_time = time.time() - time_tic\n",
+ "\n",
+ "time_tic = time.time()\n",
+ "fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
+ "fast_time = time.time() - time_tic\n",
+ "\n",
+ "print(f\"{slow_time=}, {fast_time=}\")\n",
+ "slow_time, fast_time"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "jax_env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorials/tutorial_0_state.pkl b/tutorials/tutorial_0_state.pkl
new file mode 100644
index 0000000..818bb0d
Binary files /dev/null and b/tutorials/tutorial_0_state.pkl differ