{ "cells": [ { "cell_type": "markdown", "id": "7c6d7313", "metadata": {}, "source": [ "# Tutorial 1: Genome\n", "The genome is the core component of TensorNEAT. It represents the network’s genotype." ] }, { "cell_type": "markdown", "id": "939c40b4", "metadata": {}, "source": [ "\n", "Before using a genome, create a genome instance such as `DefaultGenome`, which defines the structure and behavior of the networks it encodes. Then call `setup` to initialize it; `setup` returns the state that the operations below take as their first argument." ] }, { "cell_type": "code", "execution_count": 66, "id": "3b1836c3", "metadata": {}, "outputs": [], "source": [ "from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n", "\n", "genome = DefaultGenome(\n", " num_inputs=3, # 3 inputs\n", " num_outputs=1, # 1 output\n", " max_nodes=5, # the network will have at most 5 nodes\n", " max_conns=10, # the network will have at most 10 connections\n", " node_gene=BiasNode(), # node with 3 attributes: bias, aggregation, activation\n", " conn_gene=DefaultConn(), # connection with 1 attribute: weight\n", " mutation=DefaultMutation(\n", " node_add=0.9,\n", " node_delete=0.0,\n", " conn_add=0.9,\n", " conn_delete=0.0\n", " ) # high mutation rate for testing\n", ")\n", "\n", "state = genome.setup()" ] }, { "cell_type": "markdown", "id": "f64c565c", "metadata": {}, "source": [ "After creating the genome, we can use it to perform various network operations, including random generation, forward passes, mutation, distance calculation, and more. These operations are JIT-compilable and also support vectorization with `jax.vmap`." 
] }, { "cell_type": "code", "execution_count": 67, "id": "ef66ba76", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output=Array([-1.5388116], dtype=float32, weak_type=True)\n", "batch_output=Array([[ 2.5477247 ],\n", " [ 2.2334106 ],\n", " [ 1.8713341 ],\n", " [-3.7539673 ],\n", " [ 1.5344429 ],\n", " [ 2.7640016 ],\n", " [ 0.5649997 ],\n", " [-0.32709932],\n", " [ 3.5273829 ],\n", " [-0.64774114]], dtype=float32, weak_type=True)\n" ] } ], "source": [ "import jax, jax.numpy as jnp\n", "\n", "# Initialize a network\n", "nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n", "\n", "# Network forward\n", "single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))\n", "transformed = genome.transform(state, nodes, conns)\n", "output = genome.forward(state, transformed, single_input)\n", "print(f\"{output=}\")\n", "\n", "# Network batch forward\n", "batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n", "batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n", "print(f\"{batch_output=}\")" ] }, { "cell_type": "code", "execution_count": 68, "id": "85433b5e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pop_outputs.shape=(1000, 1)\n" ] } ], "source": [ "# Initialize a population of networks (1000 networks)\n", "pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n", " state, jax.random.split(jax.random.PRNGKey(0), 1000)\n", ")\n", "\n", "# Population forward\n", "pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n", "pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n", " state, pop_nodes, pop_conns\n", ")\n", "pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n", " state, pop_transformed, pop_inputs\n", ")\n", "\n", "print(f\"{pop_outputs.shape=}\")" ] }, { "cell_type": "code", "execution_count": 69, "id": "f1f9df60", "metadata": {}, "outputs": [], 
"source": [ "# visualize the network\n", "network = genome.network_dict(state, nodes, conns) # Transform the network from JAX arrays to a Python dict\n", "genome.visualize(network, save_path=\"./origin_network.svg\")" ] }, { "cell_type": "markdown", "id": "da5ad8ae", "metadata": {}, "source": [ "![origin_network](./origin_network.svg)" ] }, { "cell_type": "code", "execution_count": 70, "id": "16966b69", "metadata": {}, "outputs": [], "source": [ "# Mutate the network\n", "mutated_nodes, mutated_conns = genome.execute_mutation(\n", " state,\n", " jax.random.PRNGKey(2),\n", " nodes,\n", " conns,\n", " new_node_key=jnp.asarray(5), # at most 1 node can be added in each mutation\n", " new_conn_keys=jnp.asarray([6, 7, 8]), # at most 3 connections can be added\n", ")\n", "\n", "# Visualize the mutated network\n", "mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n", "genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")" ] }, { "cell_type": "markdown", "id": "c4367424", "metadata": {}, "source": [ "![mutated_networ](./mutated_network.svg)" ] }, { "cell_type": "markdown", "id": "f9922d34", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "jax_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }