tensorneat-mend/tutorials/tutorial-01-genome.ipynb
2025-01-30 16:53:24 +08:00

{
"cells": [
{
"cell_type": "markdown",
"id": "7c6d7313",
"metadata": {},
"source": [
"# Tutorial 1: Genome\n",
"The genome is the core component of TensorNEAT. It encodes a network's genotype: the node and connection genes that define its structure and parameters."
]
},
{
"cell_type": "markdown",
"id": "939c40b4",
"metadata": {},
"source": [
"Before using a genome, we create a genome instance (here `DefaultGenome`), which defines the network's size limits, gene types, and mutation behavior. We then call `setup` to obtain the initial `state`, which is passed to the genome operations that follow."
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "3b1836c3",
"metadata": {},
"outputs": [],
"source": [
"from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n",
"\n",
"genome = DefaultGenome(\n",
"    num_inputs=3,  # 3 inputs\n",
"    num_outputs=1,  # 1 output\n",
"    max_nodes=5,  # at most 5 nodes in total, including the input and output nodes\n",
"    max_conns=10,  # at most 10 connections\n",
"    node_gene=BiasNode(),  # node with 3 attributes: bias, aggregation, activation\n",
"    conn_gene=DefaultConn(),  # connection with 1 attribute: weight\n",
"    mutation=DefaultMutation(\n",
"        node_add=0.9,\n",
"        node_delete=0.0,\n",
"        conn_add=0.9,\n",
"        conn_delete=0.0,\n",
"    ),  # high add rates so that mutations are easy to observe\n",
")\n",
"\n",
"state = genome.setup()"
]
},
{
"cell_type": "markdown",
"id": "f64c565c",
"metadata": {},
"source": [
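"Once the genome is set up, we can use it for network operations such as random initialization, forward passes, mutation, and distance calculation. Most operations take `state` as their first argument, and a network must be converted with `transform` before it can be evaluated with `forward`. Because these operations work on JAX arrays, they can be compiled with `jax.jit` and vectorized with `jax.vmap`."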
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "ef66ba76",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"output=Array([-1.5388116], dtype=float32, weak_type=True)\n",
"batch_output=Array([[ 2.5477247 ],\n",
" [ 2.2334106 ],\n",
" [ 1.8713341 ],\n",
" [-3.7539673 ],\n",
" [ 1.5344429 ],\n",
" [ 2.7640016 ],\n",
" [ 0.5649997 ],\n",
" [-0.32709932],\n",
" [ 3.5273829 ],\n",
" [-0.64774114]], dtype=float32, weak_type=True)\n"
]
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"# Initialize a network\n",
"nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n",
"\n",
"# Single forward pass; the network must first be converted with `transform`\n",
"single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))\n",
"transformed = genome.transform(state, nodes, conns)\n",
"output = genome.forward(state, transformed, single_input)\n",
"print(f\"{output=}\")\n",
"\n",
"# Batched forward pass: share `state` and the network, map over the 10 inputs\n",
"batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n",
"batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n",
"print(f\"{batch_output=}\")"
]
},
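{
"cell_type": "markdown",
"id": "2b7e9a14",
"metadata": {},
"source": [
"In `jax.vmap`, `in_axes=None` marks an argument that is shared by every call, while `in_axes=0` marks one that is sliced along its first axis. The batched forward pass above therefore reuses a single `state` and a single network for all 10 inputs. The next cell is a minimal, self-contained sketch of the same pattern on a toy linear network, also wrapped in `jax.jit`. `toy_forward`, `toy_params` and `toy_state` are hypothetical stand-ins for `genome.forward`, the transformed network, and the TensorNEAT state; they are not part of the TensorNEAT API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d03c8f2",
"metadata": {},
"outputs": [],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"# Hypothetical stand-in for genome.forward(state, network, inputs):\n",
"# one linear layer mapping 3 inputs to 1 output; `state` is unused.\n",
"def toy_forward(state, params, x):\n",
"    w, b = params\n",
"    return x @ w + b\n",
"\n",
"toy_state = jnp.zeros(())  # placeholder for the TensorNEAT state\n",
"toy_params = (jnp.ones((3, 1)), jnp.zeros((1,)))  # one shared toy network\n",
"toy_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n",
"\n",
"# None: share state and network; 0: map over the rows of toy_inputs\n",
"batch_forward = jax.jit(jax.vmap(toy_forward, in_axes=(None, None, 0)))\n",
"toy_batch_output = batch_forward(toy_state, toy_params, toy_inputs)\n",
"\n",
"assert toy_batch_output.shape == (10, 1)\n",
"# Each row matches an unbatched call on the corresponding input\n",
"assert jnp.allclose(toy_batch_output[3], toy_forward(toy_state, toy_params, toy_inputs[3]), atol=1e-5)\n",
"print(f'{toy_batch_output.shape=}')"
]
},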
{
"cell_type": "code",
"execution_count": 68,
"id": "85433b5e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pop_outputs.shape=(1000, 1)\n"
]
}
],
"source": [
"# Initialize a population of networks (1000 networks)\n",
"pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n",
"    state, jax.random.split(jax.random.PRNGKey(0), 1000)\n",
")\n",
"\n",
"# Population forward pass: network i receives pop_inputs[i]\n",
"pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n",
"pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n",
"    state, pop_nodes, pop_conns\n",
")\n",
"pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n",
"    state, pop_transformed, pop_inputs\n",
")\n",
"\n",
"print(f\"{pop_outputs.shape=}\")"
]
},
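{
"cell_type": "markdown",
"id": "9a41e6d0",
"metadata": {},
"source": [
"For a population, `in_axes=(None, 0, 0)` still shares `state` but slices both the networks and the inputs, so network `i` is evaluated on `pop_inputs[i]`. The next cell sketches this on a population of 1000 toy linear networks. `toy_init` and `toy_forward` are hypothetical helpers standing in for `genome.initialize` and `genome.forward`; they are not TensorNEAT functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7c25b38",
"metadata": {},
"outputs": [],
"source": [
"import jax, jax.numpy as jnp\n",
"\n",
"# Hypothetical stand-ins for genome.initialize and genome.forward\n",
"def toy_init(key):\n",
"    # one random 3 -> 1 linear network per key\n",
"    return jax.random.normal(key, (3, 1)), jnp.zeros((1,))\n",
"\n",
"def toy_forward(state, params, x):\n",
"    w, b = params\n",
"    return x @ w + b\n",
"\n",
"toy_state = jnp.zeros(())  # placeholder for the TensorNEAT state\n",
"toy_keys = jax.random.split(jax.random.PRNGKey(0), 1000)\n",
"toy_pop_params = jax.vmap(toy_init)(toy_keys)  # w: (1000, 3, 1), b: (1000, 1)\n",
"toy_pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n",
"\n",
"# None: share state; 0: pair network i with input i\n",
"toy_pop_outputs = jax.vmap(toy_forward, in_axes=(None, 0, 0))(\n",
"    toy_state, toy_pop_params, toy_pop_inputs\n",
")\n",
"\n",
"assert toy_pop_outputs.shape == (1000, 1)\n",
"print(f'{toy_pop_outputs.shape=}')"
]
},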
{
"cell_type": "code",
"execution_count": 69,
"id": "f1f9df60",
"metadata": {},
"outputs": [],
"source": [
"# Visualize the network\n",
"network = genome.network_dict(state, nodes, conns)  # convert the network from JAX arrays to a Python dict\n",
"genome.visualize(network, save_path=\"./origin_network.svg\")"
]
},
{
"cell_type": "markdown",
"id": "da5ad8ae",
"metadata": {},
"source": [
"![Original network](./origin_network.svg)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "16966b69",
"metadata": {},
"outputs": [],
"source": [
"# Mutate the network\n",
"mutated_nodes, mutated_conns = genome.execute_mutation(\n",
"    state,\n",
"    jax.random.PRNGKey(2),\n",
"    nodes,\n",
"    conns,\n",
"    new_node_key=jnp.asarray(5),  # key for the new node; at most 1 node is added per mutation\n",
"    new_conn_keys=jnp.asarray([6, 7, 8]),  # keys for new connections; at most 3 are added per mutation\n",
")\n",
"\n",
"# Visualize the mutated network\n",
"mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n",
"genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")"
]
},
{
"cell_type": "markdown",
"id": "c4367424",
"metadata": {},
"source": [
"![Mutated network](./mutated_network.svg)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}