{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 2: Usage\n", "\n", "This tutorial shows how to use TensorNEAT to solve problems.\n", "\n", "TensorNEAT provides a **pipeline** that runs the NEAT algorithm efficiently once the required components (problem, algorithm, and pipeline) are set up.\n", "When everything is ready, call `pipeline.auto_run()` to execute the NEAT algorithm.\n", "The `auto_run()` method maximizes performance by parallelizing execution with `jax.vmap` and compiling operations with `jax.jit`, making full use of GPU acceleration.\n", "\n", "---\n", "\n", "## Types of Problems in TensorNEAT\n", "\n", "Problems to be solved with TensorNEAT fall into three cases:\n", "\n", "1. **Problems already provided by TensorNEAT** (Function Fit, Gymnax, Brax)\n", "   - Users can directly create a pipeline and execute it.\n", "\n", "2. **Problems not provided by TensorNEAT but JIT-compatible** (supporting `jax.jit`)\n", "   - Users need to define a **Custom Problem class**, then create a pipeline for execution.\n", "\n", "3. **Problems not provided by TensorNEAT and not JIT-compatible**\n", "   - Users **cannot** create a pipeline for direct execution. Instead, the NEAT algorithm must be executed manually.\n", "   - Manual execution is explained in Section 2.3." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.1 Using Existing Benchmarks\n", "\n", "TensorNEAT currently provides benchmarks for **Function Fit (Symbolic Regression)** and **Reinforcement Learning (RL) tasks** using **Gymnax** and **Brax**.\n", "\n", "To use these predefined problems, refer to the **examples** for implementation details." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.2 Custom Jitable Problem\n", "The following code demonstrates how to define a custom problem and create a pipeline for automatic execution:" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INPUTS.shape=(10, 2), LABELS.shape=(10, 1)\n" ] } ], "source": [ "# Preparation\n", "import jax, jax.numpy as jnp\n", "from tensorneat.problem import BaseProblem\n", "\n", "# The problem is to fit the Pagie polynomial\n", "def pagie_polynomial(inputs):\n", "    x, y = inputs\n", "    res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))\n", "\n", "    # Important! Return an array with one item, NOT a scalar\n", "    return jnp.array([res])\n", "\n", "# Create dataset (10 samples)\n", "INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (10, 2))\n", "LABELS = jax.vmap(pagie_polynomial)(INPUTS)\n", "\n", "print(f\"{INPUTS.shape=}, {LABELS.shape=}\")" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "# Define the custom Problem\n", "class CustomProblem(BaseProblem):\n", "\n", "    jitable = True  # necessary\n", "\n", "    def evaluate(self, state, randkey, act_func, params):\n", "        # Use ``act_func(state, params, inputs)`` to run the network forward pass\n", "\n", "        # batch forward over all inputs (using jax.vmap)\n", "        predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n", "            state, params, INPUTS\n", "        )  # shape (10, 1), matching LABELS\n", "\n", "        # calculate loss\n", "        loss = jnp.mean(jnp.square(predict - LABELS))\n", "\n", "        # return negative loss as fitness\n", "        # TensorNEAT maximizes fitness, which is equivalent to minimizing the loss\n", "        return -loss\n", "\n", "    @property\n", "    def input_shape(self):\n", "        # the input shape that the act_func expects\n", "        return (2, )\n", "\n", "    @property\n", "    def output_shape(self):\n", "        # the output shape that the act_func returns\n", "        return (1, )\n", "\n", "    def 
show(self, state, randkey, act_func, params, *args, **kwargs):\n", "        # showcase the performance of one individual\n", "        predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n", "            state, params, INPUTS\n", "        )\n", "\n", "        loss = jnp.mean(jnp.square(predict - LABELS))\n", "\n", "        msg = \"\"\n", "        for i in range(INPUTS.shape[0]):\n", "            msg += f\"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\\n\"\n", "        msg += f\"loss: {loss}\\n\"\n", "        print(msg)" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initializing\n", "initializing finished\n", "start compile\n", "compile finished, cost time: 10.318278s\n", "Generation: 1, Cost time: 45.01ms\n", " \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n", "\n", "\tnode counts: max: 4, min: 3, mean: 3.10\n", " \tconn counts: max: 3, min: 0, mean: 1.88\n", " \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n", "\n", "Generation: 2, Cost time: 53.77ms\n", " \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n", "\n", "\tnode counts: max: 5, min: 3, mean: 3.17\n", " \tconn counts: max: 4, min: 0, mean: 1.87\n", " \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n", "\n", "Generation: 3, Cost time: 21.86ms\n", " \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n", "\n", "\tnode counts: max: 6, min: 3, mean: 3.49\n", " \tconn counts: max: 6, min: 0, mean: 2.47\n", " \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n", "\n", "Generation: 4, Cost time: 24.30ms\n", " \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n", "\n", "\tnode counts: max: 6, min: 3, mean: 3.76\n", " \tconn counts: max: 6, min: 0, mean: 2.66\n", " \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 
27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n", "\n", "Generation: 5, Cost time: 27.68ms\n", " \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n", "\n", "\tnode counts: max: 6, min: 3, mean: 3.94\n", " \tconn counts: max: 6, min: 0, mean: 2.80\n", " \tspecies: 20, [145, 150, 103, 148, 21, 36, 64, 48, 34, 26, 34, 36, 39, 7, 18, 26, 37, 10, 11, 7]\n", "\n", "Generation limit reached!\n", "input: [0.85417664 0.16620052], target: [0.3481666], predict: [0.35990623]\n", "input: [0.27605474 0.48728156], target: [0.0591442], predict: [0.12697154]\n", "input: [0.9920441 0.03015983], target: [0.49201378], predict: [0.38432238]\n", "input: [0.21629429 0.37687123], target: [0.02195805], predict: [0.02771863]\n", "input: [0.63070035 0.96144474], target: [0.5973772], predict: [0.62119806]\n", "input: [0.15203023 0.92090297], target: [0.4188713], predict: [0.26794043]\n", "input: [0.30555236 0.29931295], target: [0.01660334], predict: [0.04903176]\n", "input: [0.6925707 0.8542826], target: [0.5345536], predict: [0.6080159]\n", "input: [0.46517384 0.7869307 ], target: [0.3219154], predict: [0.4150214]\n", "input: [0.99605286 0.28018546], target: [0.5021702], predict: [0.5179908]\n", "loss: 0.0055083888582885265\n", "\n" ] } ], "source": [ "import jax.numpy as jnp\n", "\n", "from tensorneat.pipeline import Pipeline\n", "from tensorneat.algorithm.neat import NEAT\n", "from tensorneat.genome import DefaultGenome, BiasNode\n", "from tensorneat.problem.func_fit import CustomFuncFit\n", "from tensorneat.common import ACT, AGG\n", "\n", "# Construct the pipeline and run\n", "pipeline = Pipeline(\n", " algorithm=NEAT(\n", " pop_size=1000,\n", " species_size=20,\n", " survival_threshold=0.01,\n", " genome=DefaultGenome(\n", " num_inputs=2,\n", " num_outputs=1,\n", " init_hidden_layers=(),\n", " node_gene=BiasNode(\n", " activation_options=[ACT.identity, ACT.inv],\n", " aggregation_options=[AGG.sum, AGG.product],\n", " ),\n", " 
output_transform=ACT.identity,\n", "        ),\n", "    ),\n", "    problem=CustomProblem(),\n", "    generation_limit=5,\n", "    fitness_target=-1e-4,\n", "    seed=42,\n", ")\n", "\n", "# initialize state\n", "state = pipeline.setup()\n", "# run until termination\n", "state, best = pipeline.auto_run(state)\n", "# show result\n", "pipeline.show(state, best)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.3 Custom Un-Jitable Problem\n", "This scenario is more complex because we cannot construct a pipeline to run the NEAT algorithm directly. The following code demonstrates how to use TensorNEAT on a custom problem that is not jitable." ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "# We use CartPole from gymnasium as the un-jitable problem\n", "import gymnasium as gym\n", "env = gym.make(\"CartPole-v1\")\n", "\n", "from tensorneat.common import State\n", "# Define the genome and pre-jit the genome functions we need\n", "genome = DefaultGenome(\n", "    num_inputs=4,\n", "    num_outputs=2,\n", "    init_hidden_layers=(),\n", "    node_gene=BiasNode(),\n", "    output_transform=jnp.argmax,\n", ")\n", "state = State(randkey=jax.random.key(0))\n", "state = genome.setup(state)\n", "\n", "jit_transform = jax.jit(genome.transform)\n", "jit_forward = jax.jit(genome.forward)" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "# define how to evaluate an individual and the population\n", "from tqdm import tqdm\n", "\n", "def evaluate(state, nodes, conns):\n", "    # evaluate the individual\n", "    transformed = jit_transform(state, nodes, conns)\n", "\n", "    observation, info = env.reset()\n", "    episode_over, total_reward = False, 0\n", "    while not episode_over:\n", "        action = jit_forward(state, transformed, observation)\n", "        # the action is currently a jax array on the device (e.g. GPU)\n", "        # move it to the host so the env can step with it\n", "        action = jax.device_get(action)\n", "\n", "        observation, reward, terminated, 
truncated, info = env.step(action)\n", "        total_reward += reward\n", "        episode_over = terminated or truncated\n", "\n", "    return total_reward\n", "\n", "def evaluate_population(state, pop_nodes, pop_conns):\n", "    # evaluate the population\n", "    pop_size = pop_nodes.shape[0]\n", "    fitness = []\n", "    for i in tqdm(range(pop_size)):\n", "        fitness.append(\n", "            evaluate(state, pop_nodes[i], pop_conns[i])\n", "        )\n", "\n", "    # return a jax array\n", "    return jnp.asarray(fitness)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [], "source": [ "# define the algorithm\n", "algorithm = NEAT(\n", "    pop_size=100,\n", "    species_size=20,\n", "    survival_threshold=0.1,\n", "    genome=genome,\n", ")\n", "state = algorithm.setup(state)\n", "\n", "# jit for acceleration\n", "jit_algorithm_ask = jax.jit(algorithm.ask)\n", "jit_algorithm_tell = jax.jit(algorithm.tell)" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Start running...\n", "Generation 0: evaluating population...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:04<00:00, 23.15it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Generation 0: best fitness: 384.0\n", "Generation 1: evaluating population...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:13<00:00, 7.37it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Generation 1: best fitness: 500.0\n", "Fitness limit reached!\n" ] } ], "source": [ "# run!\n", "print(\"Start running...\")\n", "for generation in range(10):\n", "    pop_nodes, pop_conns = jit_algorithm_ask(state)\n", "    print(f\"Generation {generation}: evaluating population...\")\n", "    fitness = evaluate_population(state, pop_nodes, pop_conns)\n", "\n", "    state = jit_algorithm_tell(state, fitness)\n", "    print(f\"Generation {generation}: best fitness: {fitness.max()}\")\n", 
"\n", "    if fitness.max() >= 500:\n", "        print(\"Fitness limit reached!\")\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above code runs slowly for the following reasons:\n", "1. A `for` loop evaluates the fitness of each individual in the population sequentially, with no parallel acceleration.\n", "2. It does not take advantage of TensorNEAT’s GPU parallel execution capabilities.\n", "3. It switches between Python code and JAX code too often, causing unnecessary overhead.\n", "\n", "\n", "The following code demonstrates an optimized `gymnasium` evaluation process:" ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [], "source": [ "# use numpy here: switching between NumPy and Python is cheaper than switching between JAX and Python\n", "import numpy as np\n", "\n", "jit_batch_transform = jax.jit(jax.vmap(genome.transform, in_axes=(None, 0, 0)))\n", "jit_batch_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, 0)))\n", "\n", "POP_SIZE = 100\n", "# Use multiple envs\n", "envs = [gym.make(\"CartPole-v1\") for _ in range(POP_SIZE)]\n", "def accelerated_evaluate_population(state, pop_nodes, pop_conns):\n", "    # transform the whole population with the batched transform\n", "    pop_transformed = jit_batch_transform(state, pop_nodes, pop_conns)\n", "\n", "    pop_observation = [env.reset()[0] for env in envs]\n", "    pop_observation = np.asarray(pop_observation)\n", "    pop_fitness = np.zeros(POP_SIZE)\n", "    episode_over = np.zeros(POP_SIZE, dtype=bool)\n", "\n", "    while not np.all(episode_over):\n", "        # batch forward\n", "        pop_action = jit_batch_forward(state, pop_transformed, pop_observation)\n", "        pop_action = jax.device_get(pop_action)\n", "\n", "        obs, reward, terminated, truncated = [], [], [], []\n", "        # we still need to step the envs one by one\n", "        # (finished envs are stepped too; their rewards are masked out below)\n", "        for i in range(POP_SIZE):\n", "            obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n", "            obs.append(obs_)\n", "            reward.append(reward_)\n", "            
terminated.append(terminated_)\n", "            truncated.append(truncated_)\n", "\n", "        pop_observation = np.asarray(obs)\n", "        pop_reward = np.asarray(reward)\n", "        pop_terminated = np.asarray(terminated)\n", "        pop_truncated = np.asarray(truncated)\n", "\n", "        # add rewards only for unfinished episodes, then update the done flags\n", "        pop_fitness += pop_reward * ~episode_over\n", "        episode_over = episode_over | pop_terminated | pop_truncated\n", "\n", "    return pop_fitness" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/wls-laptop-ubuntu/miniconda3/envs/jax_env/lib/python3.10/site-packages/gymnasium/envs/classic_control/cartpole.py:180: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n", " logger.warn(\n", "100%|██████████| 100/100 [00:14<00:00, 6.80it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "slow_time=14.704506635665894, fast_time=1.5114572048187256\n" ] }, { "data": { "text/plain": [ "(14.704506635665894, 1.5114572048187256)" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compare the speed of the two methods\n", "# run once beforehand so JAX compilation is excluded from the timing\n", "accelerated_evaluate_population(state, pop_nodes, pop_conns)\n", "\n", "import time\n", "time_tic = time.time()\n", "fitness_slow = evaluate_population(state, pop_nodes, pop_conns)\n", "slow_time = time.time() - time_tic\n", "\n", "time_tic = time.time()\n", "fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n", "fast_time = time.time() - time_tic\n", "\n", "print(f\"{slow_time=}, {fast_time=}\")\n", "slow_time, fast_time" ] } ], "metadata": { "kernelspec": { "display_name": "jax_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }