diff --git a/.gitignore b/.gitignore index a7dbb1b..30e1f6a 100644 --- a/.gitignore +++ b/.gitignore @@ -114,4 +114,6 @@ cython_debug/ # Other *.log *.pot -*.mo \ No newline at end of file +*.mo + +tutorials/.ipynb_checkpoints/* \ No newline at end of file diff --git a/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb b/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb deleted file mode 100644 index 363fcab..0000000 --- a/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb +++ /dev/null @@ -1,6 +0,0 @@ -{ - "cells": [], - "metadata": {}, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tutorials/mutated_network.svg b/tutorials/mutated_network.svg new file mode 100644 index 0000000..d0a1a5b --- /dev/null +++ b/tutorials/mutated_network.svg @@ -0,0 +1,132 @@ + + + + + + + + 2025-01-30T16:52:11.134853 + image/svg+xml + + + Matplotlib v3.9.1, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tutorials/origin_network.svg b/tutorials/origin_network.svg new file mode 100644 index 0000000..bde4a7a --- /dev/null +++ b/tutorials/origin_network.svg @@ -0,0 +1,112 @@ + + + + + + + + 2025-01-30T16:52:09.785628 + image/svg+xml + + + Matplotlib v3.9.1, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tutorials/tutorial-00-functional programming and state.ipynb b/tutorials/tutorial-00-functional programming and state.ipynb new file mode 100644 index 0000000..e0c2b2f --- /dev/null +++ b/tutorials/tutorial-00-functional programming and state.ipynb @@ -0,0 +1,334 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3b568ac2", + "metadata": {}, + "source": [ + "# Tutorial 0: Functional Programming and State" + ] + }, + { + "cell_type": "markdown", + "id": "302f305f", + "metadata": {}, + "source": [ + "## Functional Programming\n", + "\n", + "TensorNEAT uses functional programming (because it 
is based on the JAX framework, and JAX is designed for it).\n", + "\n", + "Functional Programming is a programming paradigm that treats computation as the evaluation of mathematical functions and avoids changing state and mutable data. Its main features include:\n", + "\n", + "1. **Pure Functions**: The same input always produces the same output, with no side effects.\n", + "2. **Immutable Data**: Once data is created, it cannot be changed. All operations return new data.\n", + "3. **Higher-order Functions**: Functions can be passed as arguments to other functions or returned as values." + ] + }, + { + "cell_type": "markdown", + "id": "ee0c8749", + "metadata": {}, + "source": [ + "## State\n", + "\n", + "In TensorNEAT, we use `State` to manage the input and output of functions. `State` can be seen as a Python dictionary with additional functions.\n", + "\n", + "Here are some usage examples of `State`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "557a2f24", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state_a=State ({})\n", + "state_b=State ({'a': 1, 'b': 2})\n", + "state_b.a=1\n", + "state_b.b=2\n" + ] + } + ], + "source": [ + "# import State\n", + "from tensorneat.common import State\n", + "\n", + "# create a new state\n", + "state_a = State() # no arguments\n", + "state_b = State(a=1, b=2) # kwargs\n", + "\n", + "print(f\"{state_a=}\")\n", + "print(f\"{state_b=}\")\n", + "\n", + "# get items from state, use dot notation\n", + "print(f\"{state_b.a=}\")\n", + "print(f\"{state_b.b=}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2fd169ea", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state_a=State ({'a': 1, 'b': 2})\n" + ] + } + ], + "source": [ + "# add new items to the state, use register\n", + "state_a = state_a.register(a=1, b=2)\n", + "print(f\"{state_a=}\")\n", + "\n", + "# We CANNOT register the existing 
item\n", + "# state_a = state_a.register(a=1)\n", + "# will raise ValueError(f\"Key {key} already exists in state\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8fe0d395", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state_a=State ({'a': 3, 'b': 4})\n" + ] + } + ], + "source": [ + "# update the value of an item, use update\n", + "state_a = state_a.update(a=3, b=4)\n", + "print(f\"{state_a=}\")\n", + "\n", + "# We CANNOT update the non-existing item\n", + "# state_a = state_a.update(c=3)\n", + "# will raise ValueError(f\"Key {key} does not exist in state\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b3ed0eed", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "origin_state=State ({'a': 1, 'b': 2})\n", + "new_state=State ({'a': 3, 'b': 2})\n" + ] + } + ], + "source": [ + "# State is immutable! We always create a new state, rather than modifying the existing one.\n", + "\n", + "origin_state = State(a=1, b=2)\n", + "new_state = origin_state.update(a=3)\n", + "print(f\"{origin_state=}\") # origin_state is not changed\n", + "print(f\"{new_state=}\")\n", + "\n", + "# We can not modify the state directly\n", + "# origin_state.a = 3\n", + "# will raise AttributeError: AttributeError(\"State is immutable\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8c73e60c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "new_state=State ({'a': Array(7, dtype=int32, weak_type=True), 'b': Array(7, dtype=int32, weak_type=True), 'c': Array(7, dtype=int32, weak_type=True)})\n" + ] + } + ], + "source": [ + "# State can be used in JAX functions\n", + "import jax\n", + "\n", + "\n", + "@jax.jit\n", + "def func(state):\n", + " c = state.a + state.b # fetch items from state\n", + " state = state.update(a=c, b=c) # update items in state\n", + " state = 
state.register(c=c) # add new item to state\n", + " return state # return state\n", + "\n", + "\n", + "new_state = func(state_a)\n", + "print(f\"{new_state=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2bd732ca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loaded_state=State ({'a': 1, 'b': 2, 'c': 3, 'd': 4})\n" + ] + } + ], + "source": [ + "# Save the state (use pickle) as file and load it.\n", + "state = State(a=1, b=2, c=3, d=4)\n", + "state.save(\"tutorial_0_state.pkl\")\n", + "loaded_state = State.load(\"tutorial_0_state.pkl\")\n", + "print(f\"{loaded_state=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "846b5374", + "metadata": {}, + "source": [ + "## Objects in TensorNEAT\n", + "\n", + "In the object-oriented programming (OOP) paradigm, both data and functions are stored in objects. \n", + "\n", + "In the functional programming used by TensorNEAT, data is stored in the form of JAX Tensors, while functions are stored in objects.\n", + "\n", + "For example, when we create an object `genome`, we are not creating a genome instance in the NEAT algorithm. We are actually defining some functions!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "dac2bee6", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorneat.genome import DefaultGenome\n", + "\n", + "genome = DefaultGenome(\n", + " num_inputs=3,\n", + " num_outputs=1,\n", + " max_nodes=5,\n", + " max_conns=5,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2c196cc4", + "metadata": {}, + "source": [ + "`genome` only stores functions that define the operation of the genome in the NEAT algorithm. \n", + "\n", + "To create a genome that can participate in calculation, we need to do the following things." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7c597c04", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nodes=Array([[ 0. , 0.5097862 , 1. , 0. , 0. ],\n", + " [ 1. , 0.9807121 , 1. , 0. , 0. ],\n", + " [ 2. , -0.8425486 , 1. , 0. , 0. ],\n", + " [ 3. , -0.53765106, 1. , 0. , 0. ],\n", + " [ nan, nan, nan, nan, nan]], dtype=float32, weak_type=True)\n", + "conns=Array([[0. , 3. , 0.785558 ],\n", + " [1. , 3. , 2.3734226 ],\n", + " [2. , 3. , 0.07902155],\n", + " [ nan, nan, nan],\n", + " [ nan, nan, nan]], dtype=float32, weak_type=True)\n" + ] + } + ], + "source": [ + "# setup the genome, let the genome class store some useful information in State\n", + "state = genome.setup()\n", + "\n", + "# create a new genome\n", + "randkey = jax.random.key(0)\n", + "nodes, conns = genome.initialize(state, randkey)\n", + "print(f\"{nodes=}\")\n", + "print(f\"{conns=}\")\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1d8ae137", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "outputs=Array([5.231817], dtype=float32, weak_type=True)\n" + ] + } + ], + "source": [ + "# calculate\n", + "inputs = jax.numpy.array([1, 2, 3])\n", + "\n", + "transformed = genome.transform(state, nodes, conns)\n", + "outputs = genome.forward(state, transformed, inputs)\n", + "\n", + "print(f\"{outputs=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7f01ed5d", + "metadata": {}, + "source": [ + "## Conclusion\n", + "1. TensorNEAT uses the functional programming paradigm.\n", + "2. TensorNEAT provides `State` to manage data.\n", + "3. In TensorNEAT, objects are responsible for controlling functions, rather than storing data." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/tutorial-01-genome.ipynb b/tutorials/tutorial-01-genome.ipynb index d0dbbca..33a0079 100644 --- a/tutorials/tutorial-01-genome.ipynb +++ b/tutorials/tutorial-01-genome.ipynb @@ -1,21 +1,206 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "7c6d7313", + "metadata": {}, + "source": [ + "# Tutorial 1: Genome\n", + "The genome is the core component of TensorNEAT. It represents the network’s genotype." + ] + }, + { + "cell_type": "markdown", + "id": "939c40b4", + "metadata": {}, + "source": [ + "\n", + "Before using the Genome, we need to create a `Genome` instance, which controls the behavior of the genome in use. After creating it, use `setup` to initialize." 
+ ] + }, { "cell_type": "code", - "execution_count": null, - "id": "286b5fc2-e629-4461-ad95-b69d3d606a78", + "execution_count": 66, + "id": "3b1836c3", "metadata": {}, "outputs": [], + "source": [ + "from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n", + "\n", + "genome = DefaultGenome(\n", + " num_inputs=3, # 3 inputs\n", + " num_outputs=1, # 1 output\n", + " max_nodes=5, # the network will have at most 5 nodes\n", + " max_conns=10, # the network will have at most 10 connections\n", + " node_gene=BiasNode(), # node with 3 attributes: bias, aggregation, activation\n", + " conn_gene=DefaultConn(), # connection with 1 attribute: weight\n", + " mutation=DefaultMutation(\n", + " node_add=0.9,\n", + " node_delete=0.0,\n", + " conn_add=0.9,\n", + " conn_delete=0.0\n", + " ) # high mutation rate for testing\n", + ")\n", + "\n", + "state = genome.setup()" + ] + }, + { + "cell_type": "markdown", + "id": "f64c565c", + "metadata": {}, + "source": [ + "After creating the genome, we can use it to perform various network operations, including random generation, forward passes, mutation, distance calculation, and more. These operations are JIT-compilable and also support vectorization with `jax.vmap`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "ef66ba76", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "output=Array([-1.5388116], dtype=float32, weak_type=True)\n", + "batch_output=Array([[ 2.5477247 ],\n", + " [ 2.2334106 ],\n", + " [ 1.8713341 ],\n", + " [-3.7539673 ],\n", + " [ 1.5344429 ],\n", + " [ 2.7640016 ],\n", + " [ 0.5649997 ],\n", + " [-0.32709932],\n", + " [ 3.5273829 ],\n", + " [-0.64774114]], dtype=float32, weak_type=True)\n" + ] + } + ], + "source": [ + "import jax, jax.numpy as jnp\n", + "\n", + "# Initialize a network\n", + "nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n", + "\n", + "# Network forward\n", + "single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))\n", + "transformed = genome.transform(state, nodes, conns)\n", + "output = genome.forward(state, transformed, single_input)\n", + "print(f\"{output=}\")\n", + "\n", + "# Network batch forward\n", + "batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n", + "batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n", + "print(f\"{batch_output=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "85433b5e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pop_outputs.shape=(1000, 1)\n" + ] + } + ], + "source": [ + "# Initialize a population of networks (1000 networks)\n", + "pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n", + " state, jax.random.split(jax.random.PRNGKey(0), 1000)\n", + ")\n", + "\n", + "# Population forward\n", + "pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n", + "pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n", + " state, pop_nodes, pop_conns\n", + ")\n", + "pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n", + " state, pop_transformed, pop_inputs\n", + ")\n", + 
"\n", + "print(f\"{pop_outputs.shape=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "f1f9df60", + "metadata": {}, + "outputs": [], + "source": [ + "# visualize the network\n", + "network = genome.network_dict(state, nodes, conns) # Transform the network from JAX arrays to a Python dict\n", + "genome.visualize(network, save_path=\"./origin_network.svg\")" + ] + }, + { + "cell_type": "markdown", + "id": "da5ad8ae", + "metadata": {}, + "source": [ + "![origin_network](./origin_network.svg)" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "16966b69", + "metadata": {}, + "outputs": [], + "source": [ + "# Mutate the network\n", + "mutated_nodes, mutated_conns = genome.execute_mutation(\n", + " state,\n", + " jax.random.PRNGKey(2),\n", + " nodes,\n", + " conns,\n", + " new_node_key=jnp.asarray(5), # at most 1 node can be added in each mutation\n", + " new_conn_keys=jnp.asarray([6, 7, 8]), # at most 3 connections can be added\n", + ")\n", + "\n", + "# Visualize the mutated network\n", + "mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n", + "genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")" + ] + }, + { + "cell_type": "markdown", + "id": "c4367424", + "metadata": {}, + "source": [ + "![mutated_networ](./mutated_network.svg)" + ] + }, + { + "cell_type": "markdown", + "id": "f9922d34", + "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { - "display_name": "", - "name": "" + "display_name": "jax_env", + "language": "python", + "name": "python3" }, "language_info": { - "name": "" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" } }, "nbformat": 4, diff --git a/tutorials/tutorial-02-Usage.ipynb b/tutorials/tutorial-02-Usage.ipynb new file mode 100644 index 0000000..a35d497 --- /dev/null +++ 
b/tutorials/tutorial-02-Usage.ipynb @@ -0,0 +1,528 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tutorial 2: Usage\n", + "\n", + "This tutorial introduces how to use TensorNEAT to solve problems. \n", + "\n", + "TensorNEAT provides a **pipeline**, allowing users to run the NEAT algorithm efficiently after setting up the required components (problem, algorithm, and pipeline). \n", + "Once everything is ready, users can call `pipeline.auto_run()` to execute the NEAT algorithm. \n", + "The `auto_run()` method maximizes performance by parallelizing the execution using `jax.vmap` and compiling operations with `jax.jit`, making full use of GPU acceleration.\n", + "\n", + "---\n", + "\n", + "## Types of Problems in TensorNEAT \n", + "\n", + "The problems to be solved using TensorNEAT can be categorized into the following cases:\n", + "\n", + "1. **Problems already provided by TensorNEAT** (Function Fit, Gymnax, Brax) \n", + " - In this case, users can directly create a pipeline and execute it.\n", + "\n", + "2. **Problems not provided by TensorNEAT but are JIT-compatible** (supporting `jax.jit`) \n", + " - Users need to create a **Custom Problem class**, then create a pipeline for execution.\n", + "\n", + "3. **Problems not provided by TensorNEAT and not JIT-compatible** \n", + " - In this case, users **cannot** create a pipeline for direct execution. Instead, the NEAT algorithm must be manually executed. \n", + " - The detailed method for manual execution is explained below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.1 Using Existing Benchmarks\n", + "\n", + "TensorNEAT currently provides benchmarks for **Function Fit (Symbolic Regression)** and **Reinforcement Learning (RL) tasks** using **Gymnax** and **Brax**. \n", + "\n", + "If you want to use these predefined problems, refer to the **examples** for implementation details." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.2 Custom Jitable Problem\n", + "The following code demonstrates how users can define a custom problem and create a pipeline for automatic execution:" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUTS.shape=(10, 2), LABELS.shape=(10, 1)\n" + ] + } + ], + "source": [ + "# Preparation\n", + "import jax, jax.numpy as jnp\n", + "from tensorneat.problem import BaseProblem\n", + "\n", + "# The problem is to fit pagie_polynomial\n", + "def pagie_polynomial(inputs):\n", + " x, y = inputs\n", + " res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))\n", + "\n", + " # Important! Returns an array with one item, NOT a scalar\n", + " return jnp.array([res])\n", + "\n", + "# Create dataset (10 samples)\n", + "INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (10, 2))\n", + "LABELS = jax.vmap(pagie_polynomial)(INPUTS)\n", + "\n", + "print(f\"{INPUTS.shape=}, {LABELS.shape=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the custom Problem\n", + "class CustomProblem(BaseProblem):\n", + "\n", + " jitable = True # necessary\n", + "\n", + " def evaluate(self, state, randkey, act_func, params):\n", + " # Use ``act_func(state, params, inputs)`` to do network forward\n", + "\n", + " # do batch forward for all inputs (using jax.vmap)\n", + " predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n", + " state, params, INPUTS\n", + " ) # should be shape (10, 1)\n", + "\n", + " # calculate loss\n", + " loss = jnp.mean(jnp.square(predict - LABELS))\n", + "\n", + " # return negative loss as fitness \n", + " # TensorNEAT maximizes fitness, equivalent to minimizing the loss\n", + " return -loss\n", + "\n", + " @property\n", + " def input_shape(self):\n", + " # the input shape that the act_func expects\n", + " return (2, 
)\n", + " \n", + " @property\n", + " def output_shape(self):\n", + " # the output shape that the act_func returns\n", + " return (1, )\n", + " \n", + " def show(self, state, randkey, act_func, params, *args, **kwargs):\n", + " # shocase the performance of one individual\n", + " predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n", + " state, params, INPUTS\n", + " )\n", + "\n", + " loss = jnp.mean(jnp.square(predict - LABELS))\n", + "\n", + " msg = \"\"\n", + " for i in range(INPUTS.shape[0]):\n", + " msg += f\"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\\n\"\n", + " msg += f\"loss: {loss}\\n\"\n", + " print(msg)" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initializing\n", + "initializing finished\n", + "start compile\n", + "compile finished, cost time: 10.318278s\n", + "Generation: 1, Cost time: 45.01ms\n", + " \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n", + "\n", + "\tnode counts: max: 4, min: 3, mean: 3.10\n", + " \tconn counts: max: 3, min: 0, mean: 1.88\n", + " \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n", + "\n", + "Generation: 2, Cost time: 53.77ms\n", + " \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n", + "\n", + "\tnode counts: max: 5, min: 3, mean: 3.17\n", + " \tconn counts: max: 4, min: 0, mean: 1.87\n", + " \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n", + "\n", + "Generation: 3, Cost time: 21.86ms\n", + " \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n", + "\n", + "\tnode counts: max: 6, min: 3, mean: 3.49\n", + " \tconn counts: max: 6, min: 0, mean: 2.47\n", + " \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n", + "\n", + "Generation: 4, Cost time: 
24.30ms\n", + " \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n", + "\n", + "\tnode counts: max: 6, min: 3, mean: 3.76\n", + " \tconn counts: max: 6, min: 0, mean: 2.66\n", + " \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n", + "\n", + "Generation: 5, Cost time: 27.68ms\n", + " \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n", + "\n", + "\tnode counts: max: 6, min: 3, mean: 3.94\n", + " \tconn counts: max: 6, min: 0, mean: 2.80\n", + " \tspecies: 20, [145, 150, 103, 148, 21, 36, 64, 48, 34, 26, 34, 36, 39, 7, 18, 26, 37, 10, 11, 7]\n", + "\n", + "Generation limit reached!\n", + "input: [0.85417664 0.16620052], target: [0.3481666], predict: [0.35990623]\n", + "input: [0.27605474 0.48728156], target: [0.0591442], predict: [0.12697154]\n", + "input: [0.9920441 0.03015983], target: [0.49201378], predict: [0.38432238]\n", + "input: [0.21629429 0.37687123], target: [0.02195805], predict: [0.02771863]\n", + "input: [0.63070035 0.96144474], target: [0.5973772], predict: [0.62119806]\n", + "input: [0.15203023 0.92090297], target: [0.4188713], predict: [0.26794043]\n", + "input: [0.30555236 0.29931295], target: [0.01660334], predict: [0.04903176]\n", + "input: [0.6925707 0.8542826], target: [0.5345536], predict: [0.6080159]\n", + "input: [0.46517384 0.7869307 ], target: [0.3219154], predict: [0.4150214]\n", + "input: [0.99605286 0.28018546], target: [0.5021702], predict: [0.5179908]\n", + "loss: 0.0055083888582885265\n", + "\n" + ] + } + ], + "source": [ + "import jax.numpy as jnp\n", + "\n", + "from tensorneat.pipeline import Pipeline\n", + "from tensorneat.algorithm.neat import NEAT\n", + "from tensorneat.genome import DefaultGenome, BiasNode\n", + "from tensorneat.problem.func_fit import CustomFuncFit\n", + "from tensorneat.common import ACT, AGG\n", + "\n", + "# Construct the pipeline and run\n", + "pipeline = Pipeline(\n", + " 
algorithm=NEAT(\n", + " pop_size=1000,\n", + " species_size=20,\n", + " survival_threshold=0.01,\n", + " genome=DefaultGenome(\n", + " num_inputs=2,\n", + " num_outputs=1,\n", + " init_hidden_layers=(),\n", + " node_gene=BiasNode(\n", + " activation_options=[ACT.identity, ACT.inv],\n", + " aggregation_options=[AGG.sum, AGG.product],\n", + " ),\n", + " output_transform=ACT.identity,\n", + " ),\n", + " ),\n", + " problem=CustomProblem(),\n", + " generation_limit=5,\n", + " fitness_target=-1e-4,\n", + " seed=42,\n", + ")\n", + "\n", + "# initialize state\n", + "state = pipeline.setup()\n", + "# run until termination\n", + "state, best = pipeline.auto_run(state)\n", + "# show result\n", + "pipeline.show(state, best)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.3 Custom Un-Jitable Problem \n", + "This scenario is more complex because we cannot directly construct a pipeline to run the NEAT algorithm. The following code demonstrates how to use TensorNEAT to execute an un-jitable custom problem." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [], + "source": [ + "# We use Cartpole in gymnasium as the Un-jitable problem\n", + "import gymnasium as gym\n", + "env = gym.make(\"CartPole-v1\")\n", + "\n", + "from tensorneat.common import State\n", + "# Define the genome and Pre jit necessary functions in genome\n", + "genome=DefaultGenome(\n", + " num_inputs=4,\n", + " num_outputs=2,\n", + " init_hidden_layers=(),\n", + " node_gene=BiasNode(),\n", + " output_transform=jnp.argmax,\n", + ")\n", + "state = State(randkey=jax.random.key(0))\n", + "state = genome.setup(state)\n", + "\n", + "jit_transform = jax.jit(genome.transform)\n", + "jit_forward = jax.jit(genome.forward)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "# define the method to evaluate the individual and the population\n", + "from tqdm import tqdm\n", + "\n", + "def evaluate(state, nodes, conns):\n", + " # evaluate the individual\n", + " transformed = jit_transform(state, nodes, conns)\n", + "\n", + " observation, info = env.reset()\n", + " episode_over, total_reward = False, 0\n", + " while not episode_over:\n", + " action = jit_forward(state, transformed, observation)\n", + " # currently the action is a jax array on gpu\n", + " # we need move it to cpu for env step\n", + " action = jax.device_get(action)\n", + "\n", + " observation, reward, terminated, truncated, info = env.step(action)\n", + " total_reward += reward\n", + " episode_over = terminated or truncated\n", + "\n", + " return total_reward\n", + "\n", + "def evaluate_population(state, pop_nodes, pop_conns):\n", + " # evaluate the population\n", + " pop_size = pop_nodes.shape[0]\n", + " fitness = []\n", + " for i in tqdm(range(pop_size)):\n", + " fitness.append(\n", + " evaluate(state, pop_nodes[i], pop_conns[i])\n", + " )\n", + "\n", + " # return a jax array\n", + " return jnp.asarray(fitness)" + ] + }, + { + "cell_type": "code", 
+ "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "# define the algorithm\n", + "algorithm = NEAT(\n", + " pop_size=100,\n", + " species_size=20,\n", + " survival_threshold=0.1,\n", + " genome=genome,\n", + ")\n", + "state = algorithm.setup(state)\n", + "\n", + "# jit for acceleration\n", + "jit_algorithm_ask = jax.jit(algorithm.ask)\n", + "jit_algorithm_tell = jax.jit(algorithm.tell)" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Start running...\n", + "Generation 0: evaluating population...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:04<00:00, 23.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation 0: best fitness: 384.0\n", + "Generation 1: evaluating population...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:13<00:00, 7.37it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation 1: best fitness: 500.0\n", + "Fitness limit reached!\n" + ] + } + ], + "source": [ + "# run!\n", + "print(\"Start running...\")\n", + "for generation in range(10):\n", + " pop_nodes, pop_conns = jit_algorithm_ask(state)\n", + " print(f\"Generation {generation}: evaluating population...\")\n", + " fitness = evaluate_population(state, pop_nodes, pop_conns)\n", + "\n", + " state = jit_algorithm_tell(state, fitness)\n", + " print(f\"Generation {generation}: best fitness: {fitness.max()}\")\n", + "\n", + " if fitness.max() >= 500:\n", + " print(\"Fitness limit reached!\")\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above code runs slowly due to the following reasons:\n", + "1. 
We use a `for` loop to evaluate the fitness of each individual in the population sequentially, lacking parallel acceleration.\n", + "2. We do not take advantage of TensorNEAT’s GPU parallel execution capabilities.\n", + "3. There are too many switches between Python code and JAX code, causing unnecessary overhead.\n", + "\n", + "\n", + "The following code demonstrates an optimized `gymnasium` evaluation process:" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "# use numpy because a numpy-python switch takes less time than a jax-python switch\n", + "import numpy as np\n", + "\n", + "jit_batch_transform = jax.jit(jax.vmap(genome.transform, in_axes=(None, 0, 0)))\n", + "jit_batch_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, 0)))\n", + "\n", + "POP_SIZE = 100\n", + "# Use multiple envs\n", + "envs = [gym.make(\"CartPole-v1\") for _ in range(POP_SIZE)]\n", + "def accelerated_evaluate_population(state, pop_nodes, pop_conns):\n", + " # transform the population using batch transform\n", + " pop_transformed = jit_batch_transform(state, pop_nodes, pop_conns)\n", + "\n", + " pop_observation = [env.reset()[0] for env in envs]\n", + " pop_observation = np.asarray(pop_observation)\n", + " pop_fitness = np.zeros(POP_SIZE)\n", + " episode_over = np.zeros(POP_SIZE, dtype=bool)\n", + " \n", + " while not np.all(episode_over):\n", + " # batch forward\n", + " pop_action = jit_batch_forward(state, pop_transformed, pop_observation)\n", + " pop_action = jax.device_get(pop_action)\n", + "\n", + " obs, reward, terminated, truncated = [], [], [], []\n", + " # we still need to step the envs one by one\n", + " for i in range(POP_SIZE):\n", + " obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n", + " obs.append(obs_)\n", + " reward.append(reward_)\n", + " terminated.append(terminated_)\n", + " truncated.append(truncated_)\n", + "\n", + " pop_observation = np.asarray(obs)\n", + " pop_reward = 
np.asarray(reward)\n", + " pop_terminated = np.asarray(terminated)\n", + " pop_truncated = np.asarray(truncated)\n", + "\n", + " # update fitness and over\n", + " pop_fitness += pop_reward * ~episode_over\n", + " episode_over = episode_over | pop_terminated | pop_truncated\n", + "\n", + " return pop_fitness" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/wls-laptop-ubuntu/miniconda3/envs/jax_env/lib/python3.10/site-packages/gymnasium/envs/classic_control/cartpole.py:180: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n", + " logger.warn(\n", + "100%|██████████| 100/100 [00:14<00:00, 6.80it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "slow_time=14.704506635665894, fast_time=1.5114572048187256\n" + ] + }, + { + "data": { + "text/plain": [ + "(14.704506635665894, 1.5114572048187256)" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Compare the speed between these two methods\n", + "# prerun once for jax compile\n", + "accelerated_evaluate_population(state, pop_nodes, pop_conns)\n", + "\n", + "import time\n", + "time_tic = time.time()\n", + "fitness_slow = evaluate_population(state, pop_nodes, pop_conns)\n", + "slow_time = time.time() - time_tic\n", + "\n", + "time_tic = time.time()\n", + "fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n", + "fast_time = time.time() - time_tic\n", + "\n", + "print(f\"{slow_time=}, {fast_time=}\")\n", + "slow_time, fast_time" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + 
"name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/tutorial_0_state.pkl b/tutorials/tutorial_0_state.pkl new file mode 100644 index 0000000..818bb0d Binary files /dev/null and b/tutorials/tutorial_0_state.pkl differ