{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial 2: Usage\n",
"\n",
"This tutorial shows how to use TensorNEAT to solve problems.\n",
"\n",
"TensorNEAT provides a **pipeline** that runs the NEAT algorithm efficiently once the required components (problem, algorithm, and pipeline) are set up.  \n",
"Once everything is ready, users can call `pipeline.auto_run()` to execute the NEAT algorithm.  \n",
"The `auto_run()` method maximizes performance by parallelizing the execution using `jax.vmap` and compiling operations with `jax.jit`, making full use of GPU acceleration.\n",
"\n",
"---\n",
"\n",
"## Types of Problems in TensorNEAT\n",
"\n",
"Problems to be solved with TensorNEAT fall into the following cases:\n",
"\n",
"1. **Problems already provided by TensorNEAT** (Function Fit, Gymnax, Brax)\n",
"   - In this case, users can directly create a pipeline and execute it.\n",
"\n",
"2. **Problems not provided by TensorNEAT but JIT-compatible** (supporting `jax.jit`)\n",
"   - Users need to create a **Custom Problem class**, then create a pipeline for execution.\n",
"\n",
"3. **Problems not provided by TensorNEAT and not JIT-compatible**\n",
"   - In this case, users **cannot** create a pipeline for direct execution. Instead, the NEAT algorithm must be executed manually.\n",
"   - The detailed method for manual execution is explained below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.1 Using Existing Benchmarks\n",
"\n",
"TensorNEAT currently provides benchmarks for **Function Fit (Symbolic Regression)** and **Reinforcement Learning (RL) tasks** using **Gymnax** and **Brax**.\n",
"\n",
"If you want to use these predefined problems, refer to the **examples** for implementation details."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.2 Custom Jitable Problem\n",
"\n",
"The following code demonstrates how users can define a custom problem and create a pipeline for automatic execution:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INPUTS.shape=(10, 2), LABELS.shape=(10, 1)\n"
]
}
],
"source": [
"# Preparation\n",
"import jax, jax.numpy as jnp\n",
"from tensorneat.problem import BaseProblem\n",
"\n",
"# The problem is to fit the Pagie polynomial\n",
"def pagie_polynomial(inputs):\n",
"    x, y = inputs\n",
"    res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))\n",
"\n",
"    # Important! Return an array with one item, NOT a scalar\n",
"    return jnp.array([res])\n",
"\n",
"# Create dataset (10 samples)\n",
"INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (10, 2))\n",
"LABELS = jax.vmap(pagie_polynomial)(INPUTS)\n",
"\n",
"print(f\"{INPUTS.shape=}, {LABELS.shape=}\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Define the custom Problem\n",
"class CustomProblem(BaseProblem):\n",
"\n",
"    jitable = True  # necessary\n",
"\n",
"    def evaluate(self, state, randkey, act_func, params):\n",
"        # Use ``act_func(state, params, inputs)`` to run the network forward pass\n",
"\n",
"        # batched forward pass over all inputs (using jax.vmap)\n",
"        predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
"            state, params, INPUTS\n",
"        )  # should have shape (10, 1), matching LABELS\n",
"\n",
"        # calculate loss\n",
"        loss = jnp.mean(jnp.square(predict - LABELS))\n",
"\n",
"        # return negative loss as fitness\n",
"        # TensorNEAT maximizes fitness, which is equivalent to minimizing loss\n",
"        return -loss\n",
"\n",
"    @property\n",
"    def input_shape(self):\n",
"        # the input shape that the act_func expects\n",
"        return (2, )\n",
"\n",
"    @property\n",
"    def output_shape(self):\n",
"        # the output shape that the act_func returns\n",
"        return (1, )\n",
"\n",
"    def show(self, state, randkey, act_func, params, *args, **kwargs):\n",
"        # showcase the performance of one individual\n",
"        predict = jax.vmap(act_func, in_axes=(None, None, 0))(\n",
"            state, params, INPUTS\n",
"        )\n",
"\n",
"        loss = jnp.mean(jnp.square(predict - LABELS))\n",
"\n",
"        msg = \"\"\n",
"        for i in range(INPUTS.shape[0]):\n",
"            msg += f\"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\\n\"\n",
"        msg += f\"loss: {loss}\\n\"\n",
"        print(msg)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initializing\n",
"initializing finished\n",
"start compile\n",
"compile finished, cost time: 10.663079s\n",
"Generation: 1, Cost time: 64.31ms\n",
" \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n",
"\n",
"\tnode counts: max: 4, min: 3, mean: 3.10\n",
" \tconn counts: max: 3, min: 0, mean: 1.88\n",
" \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n",
"\n",
"Generation: 2, Cost time: 65.52ms\n",
" \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n",
"\n",
"\tnode counts: max: 5, min: 3, mean: 3.17\n",
" \tconn counts: max: 4, min: 0, mean: 1.87\n",
" \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n",
"\n",
"Generation: 3, Cost time: 59.10ms\n",
" \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n",
"\n",
"\tnode counts: max: 6, min: 3, mean: 3.49\n",
" \tconn counts: max: 6, min: 0, mean: 2.47\n",
" \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n",
"\n",
"Generation: 4, Cost time: 34.24ms\n",
" \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n",
"\n",
"\tnode counts: max: 6, min: 3, mean: 3.76\n",
" \tconn counts: max: 6, min: 0, mean: 2.66\n",
" \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n",
"\n",
"Generation: 5, Cost time: 20.36ms\n",
" \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n",
"\n",
"\tnode counts: max: 6, min: 3, mean: 3.94\n",
" \tconn counts: max: 6, min: 0, mean: 2.80\n",
" \tspecies: 20, [145, 150, 103, 148, 21, 36, 64, 48, 34, 26, 34, 36, 39, 7, 18, 26, 37, 10, 11, 7]\n",
"\n",
"Generation limit reached!\n",
"input: [0.85417664 0.16620052], target: [0.3481666], predict: [0.35990623]\n",
"input: [0.27605474 0.48728156], target: [0.0591442], predict: [0.12697154]\n",
"input: [0.9920441 0.03015983], target: [0.49201378], predict: [0.38432238]\n",
"input: [0.21629429 0.37687123], target: [0.02195805], predict: [0.02771863]\n",
"input: [0.63070035 0.96144474], target: [0.5973772], predict: [0.62119806]\n",
"input: [0.15203023 0.92090297], target: [0.4188713], predict: [0.26794043]\n",
"input: [0.30555236 0.29931295], target: [0.01660334], predict: [0.04903176]\n",
"input: [0.6925707 0.8542826], target: [0.5345536], predict: [0.6080159]\n",
"input: [0.46517384 0.7869307 ], target: [0.3219154], predict: [0.4150214]\n",
"input: [0.99605286 0.28018546], target: [0.5021702], predict: [0.5179908]\n",
"loss: 0.0055083888582885265\n",
"\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"\n",
"from tensorneat.pipeline import Pipeline\n",
"from tensorneat.algorithm.neat import NEAT\n",
"from tensorneat.genome import DefaultGenome, BiasNode\n",
"from tensorneat.problem.func_fit import CustomFuncFit\n",
"from tensorneat.common import ACT, AGG\n",
"\n",
"# Construct the pipeline and run\n",
"pipeline = Pipeline(\n",
"    algorithm=NEAT(\n",
"        pop_size=1000,\n",
"        species_size=20,\n",
"        survival_threshold=0.01,\n",
"        genome=DefaultGenome(\n",
"            num_inputs=2,\n",
"            num_outputs=1,\n",
"            init_hidden_layers=(),\n",
"            node_gene=BiasNode(\n",
"                activation_options=[ACT.identity, ACT.inv],\n",
"                aggregation_options=[AGG.sum, AGG.product],\n",
"            ),\n",
"            output_transform=ACT.identity,\n",
"        ),\n",
"    ),\n",
"    problem=CustomProblem(),\n",
"    generation_limit=5,\n",
"    fitness_target=-1e-4,\n",
"    seed=42,\n",
")\n",
"\n",
"# initialize state\n",
"state = pipeline.setup()\n",
"# run until termination\n",
"state, best = pipeline.auto_run(state)\n",
"# show result\n",
"pipeline.show(state, best)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.3 Custom Un-Jitable Problem\n",
"\n",
"This scenario is more complex because we cannot directly construct a pipeline to run the NEAT algorithm. The following code demonstrates how to use TensorNEAT on an un-jitable custom problem."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# We use CartPole from gymnasium as the un-jitable problem\n",
"import gymnasium as gym\n",
"env = gym.make(\"CartPole-v1\")\n",
"\n",
"from tensorneat.common import State\n",
"# Define the genome and pre-jit the genome functions we need\n",
"genome = DefaultGenome(\n",
"    num_inputs=4,\n",
"    num_outputs=2,\n",
"    init_hidden_layers=(),\n",
"    node_gene=BiasNode(),\n",
"    output_transform=jnp.argmax,\n",
")\n",
"state = State(randkey=jax.random.key(0))\n",
"state = genome.setup(state)\n",
"\n",
"jit_transform = jax.jit(genome.transform)\n",
"jit_forward = jax.jit(genome.forward)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# define how to evaluate an individual and the whole population\n",
"from tqdm import tqdm\n",
"\n",
"def evaluate(state, nodes, conns):\n",
"    # evaluate the individual\n",
"    transformed = jit_transform(state, nodes, conns)\n",
"\n",
"    observation, info = env.reset()\n",
"    episode_over, total_reward = False, 0\n",
"    while not episode_over:\n",
"        action = jit_forward(state, transformed, observation)\n",
"        # the action is currently a JAX array on the device (e.g. GPU);\n",
"        # move it to the host (CPU) before stepping the env\n",
"        action = jax.device_get(action)\n",
"\n",
"        observation, reward, terminated, truncated, info = env.step(action)\n",
"        total_reward += reward\n",
"        episode_over = terminated or truncated\n",
"\n",
"    return total_reward\n",
"\n",
"def evaluate_population(state, pop_nodes, pop_conns):\n",
"    # evaluate the population\n",
"    pop_size = pop_nodes.shape[0]\n",
"    fitness = []\n",
"    for i in tqdm(range(pop_size)):\n",
"        fitness.append(\n",
"            evaluate(state, pop_nodes[i], pop_conns[i])\n",
"        )\n",
"\n",
"    # return a jax array\n",
"    return jnp.asarray(fitness)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# define the algorithm\n",
"algorithm = NEAT(\n",
"    pop_size=100,\n",
"    species_size=20,\n",
"    survival_threshold=0.1,\n",
"    genome=genome,\n",
")\n",
"state = algorithm.setup(state)\n",
"\n",
"# jit for acceleration\n",
"jit_algorithm_ask = jax.jit(algorithm.ask)\n",
"jit_algorithm_tell = jax.jit(algorithm.tell)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start running...\n",
"Generation 0: evaluating population...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:04<00:00, 22.38it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 0: best fitness: 493.0\n",
"Generation 1: evaluating population...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:17<00:00, 5.62it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 1: best fitness: 500.0\n",
"Fitness limit reached!\n"
]
}
],
"source": [
"# run!\n",
"print(\"Start running...\")\n",
"for generation in range(10):\n",
"    pop_nodes, pop_conns = jit_algorithm_ask(state)\n",
"    print(f\"Generation {generation}: evaluating population...\")\n",
"    fitness = evaluate_population(state, pop_nodes, pop_conns)\n",
"\n",
"    state = jit_algorithm_tell(state, fitness)\n",
"    print(f\"Generation {generation}: best fitness: {fitness.max()}\")\n",
"\n",
"    if fitness.max() >= 500:\n",
"        print(\"Fitness limit reached!\")\n",
"        break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code above runs slowly for the following reasons:\n",
"\n",
"1. A `for` loop evaluates the fitness of each individual in the population sequentially, with no parallel acceleration.\n",
"2. It does not take advantage of TensorNEAT’s GPU parallel execution capabilities.\n",
"3. It switches between Python code and JAX code too often, causing unnecessary overhead.\n",
"\n",
"The following code demonstrates an optimized `gymnasium` evaluation process:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# use NumPy: switching between NumPy and Python is cheaper than switching between JAX and Python\n",
"import numpy as np\n",
"\n",
"jit_batch_transform = jax.jit(jax.vmap(genome.transform, in_axes=(None, 0, 0)))\n",
"jit_batch_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, 0)))\n",
"\n",
"POP_SIZE = 100\n",
"# Use multiple envs, one per individual\n",
"envs = [gym.make(\"CartPole-v1\") for _ in range(POP_SIZE)]\n",
"def accelerated_evaluate_population(state, pop_nodes, pop_conns):\n",
"    # transform the whole population with the batched transform\n",
"    pop_transformed = jit_batch_transform(state, pop_nodes, pop_conns)\n",
"\n",
"    pop_observation = [env.reset()[0] for env in envs]\n",
"    pop_observation = np.asarray(pop_observation)\n",
"    pop_fitness = np.zeros(POP_SIZE)\n",
"    episode_over = np.zeros(POP_SIZE, dtype=bool)\n",
"\n",
"    while not np.all(episode_over):\n",
"        # batch forward\n",
"        pop_action = jit_batch_forward(state, pop_transformed, pop_observation)\n",
"        pop_action = jax.device_get(pop_action)\n",
"\n",
"        obs, reward, terminated, truncated = [], [], [], []\n",
"        # we still need to step the envs one by one\n",
"        for i in range(POP_SIZE):\n",
"            if episode_over[i]:\n",
"                # this episode has already ended\n",
"                # append zeros to keep the shape\n",
"                obs.append(np.zeros(4))\n",
"                reward.append(0.0)\n",
"                terminated.append(True)\n",
"                truncated.append(False)\n",
"                continue\n",
"            else:\n",
"                # step the env\n",
"                obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n",
"                obs.append(obs_)\n",
"                reward.append(reward_)\n",
"                terminated.append(terminated_)\n",
"                truncated.append(truncated_)\n",
"\n",
"        pop_observation = np.asarray(obs)\n",
"        pop_reward = np.asarray(reward)\n",
"        pop_terminated = np.asarray(terminated)\n",
"        pop_truncated = np.asarray(truncated)\n",
"\n",
"        # update fitness and the episode-over flags\n",
"        pop_fitness += pop_reward * ~episode_over\n",
"        episode_over = episode_over | pop_terminated | pop_truncated\n",
"\n",
"    return pop_fitness"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:14<00:00, 6.87it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"slow_time=14.562041997909546, fast_time=1.134758710861206\n"
]
}
],
"source": [
"# Compare the speed of the two methods\n",
"# run once beforehand so JAX compilation is excluded from the timing\n",
"accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
"\n",
"import time\n",
"time_tic = time.time()\n",
"fitness_slow = evaluate_population(state, pop_nodes, pop_conns)\n",
"slow_time = time.time() - time_tic\n",
"\n",
"time_tic = time.time()\n",
"fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
"fast_time = time.time() - time_tic\n",
"\n",
"print(f\"{slow_time=}, {fast_time=}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}