Add tutorials

This commit is contained in:
wls2002
2025-01-30 16:53:24 +08:00
parent ee1a2a8271
commit f67c69776a
8 changed files with 1299 additions and 12 deletions

View File

@@ -1,21 +1,206 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7c6d7313",
"metadata": {},
"source": [
"# Tutorial 1: Genome\n",
"The genome is the core component of TensorNEAT. It represents the networks genotype."
]
},
{
"cell_type": "markdown",
"id": "939c40b4",
"metadata": {},
"source": [
"\n",
"Before using the Genome, we need to create a `Genome` instance, which controls the behavior of the genome in use. After creating it, use `setup` to initialize."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "286b5fc2-e629-4461-ad95-b69d3d606a78",
"execution_count": 66,
"id": "3b1836c3",
"metadata": {},
"outputs": [],
"source": [
"from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n",
"\n",
"genome = DefaultGenome(\n",
" num_inputs=3, # 3 inputs\n",
" num_outputs=1, # 1 output\n",
" max_nodes=5, # the network will have at most 5 nodes\n",
" max_conns=10, # the network will have at most 10 connections\n",
" node_gene=BiasNode(), # node with 3 attributes: bias, aggregation, activation\n",
" conn_gene=DefaultConn(), # connection with 1 attribute: weight\n",
" mutation=DefaultMutation(\n",
" node_add=0.9,\n",
" node_delete=0.0,\n",
" conn_add=0.9,\n",
" conn_delete=0.0\n",
" ) # high mutation rate for testing\n",
")\n",
"\n",
"state = genome.setup()"
]
},
{
"cell_type": "markdown",
"id": "f64c565c",
"metadata": {},
"source": [
"After creating the genome, we can use it to perform various network operations, including random generation, forward passes, mutation, distance calculation, and more. These operations are JIT-compilable and also support vectorization with `jax.vmap`."
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "ef66ba76",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"output=Array([-1.5388116], dtype=float32, weak_type=True)\n",
"batch_output=Array([[ 2.5477247 ],\n",
" [ 2.2334106 ],\n",
" [ 1.8713341 ],\n",
" [-3.7539673 ],\n",
" [ 1.5344429 ],\n",
" [ 2.7640016 ],\n",
" [ 0.5649997 ],\n",
" [-0.32709932],\n",
" [ 3.5273829 ],\n",
" [-0.64774114]], dtype=float32, weak_type=True)\n"
]
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"# Initialize a network\n",
"nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n",
"\n",
"# Network forward\n",
"single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))\n",
"transformed = genome.transform(state, nodes, conns)\n",
"output = genome.forward(state, transformed, single_input)\n",
"print(f\"{output=}\")\n",
"\n",
"# Network batch forward\n",
"batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n",
"batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n",
"print(f\"{batch_output=}\")"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "85433b5e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pop_outputs.shape=(1000, 1)\n"
]
}
],
"source": [
"# Initialize a population of networks (1000 networks)\n",
"pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n",
" state, jax.random.split(jax.random.PRNGKey(0), 1000)\n",
")\n",
"\n",
"# Population forward\n",
"pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n",
"pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n",
" state, pop_nodes, pop_conns\n",
")\n",
"pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n",
" state, pop_transformed, pop_inputs\n",
")\n",
"\n",
"print(f\"{pop_outputs.shape=}\")"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "f1f9df60",
"metadata": {},
"outputs": [],
"source": [
"# visualize the network\n",
"network = genome.network_dict(state, nodes, conns) # Transform the network from JAX arrays to a Python dict\n",
"genome.visualize(network, save_path=\"./origin_network.svg\")"
]
},
{
"cell_type": "markdown",
"id": "da5ad8ae",
"metadata": {},
"source": [
"![origin_network](./origin_network.svg)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "16966b69",
"metadata": {},
"outputs": [],
"source": [
"# Mutate the network\n",
"mutated_nodes, mutated_conns = genome.execute_mutation(\n",
" state,\n",
" jax.random.PRNGKey(2),\n",
" nodes,\n",
" conns,\n",
" new_node_key=jnp.asarray(5), # at most 1 node can be added in each mutation\n",
" new_conn_keys=jnp.asarray([6, 7, 8]), # at most 3 connections can be added\n",
")\n",
"\n",
"# Visualize the mutated network\n",
"mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n",
"genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")"
]
},
{
"cell_type": "markdown",
"id": "c4367424",
"metadata": {},
"source": [
"![mutated_networ](./mutated_network.svg)"
]
},
{
"cell_type": "markdown",
"id": "f9922d34",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "",
"name": ""
"display_name": "jax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"name": ""
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,