{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7c6d7313",
   "metadata": {},
   "source": [
    "# Tutorial 1: Genome\n",
    "The genome is the core component of TensorNEAT. It represents the network’s genotype."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "939c40b4",
   "metadata": {},
   "source": [
    "Before using a genome, create a `Genome` instance (here, `DefaultGenome`), which controls how the genome behaves. Then call `setup` to initialize it. `setup` returns the state that the later calls take as their first argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "3b1836c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation\n",
    "\n",
    "genome = DefaultGenome(\n",
    "    num_inputs=3,  # 3 inputs\n",
    "    num_outputs=1,  # 1 output\n",
    "    max_nodes=5,  # the network will have at most 5 nodes\n",
    "    max_conns=10,  # the network will have at most 10 connections\n",
    "    node_gene=BiasNode(),  # node with 3 attributes: bias, aggregation, activation\n",
    "    conn_gene=DefaultConn(),  # connection with 1 attribute: weight\n",
    "    mutation=DefaultMutation(\n",
    "        node_add=0.9,\n",
    "        node_delete=0.0,\n",
    "        conn_add=0.9,\n",
    "        conn_delete=0.0\n",
    "    )  # high mutation rate for testing\n",
    ")\n",
    "\n",
    "state = genome.setup()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f64c565c",
   "metadata": {},
   "source": [
    "With the genome set up, we can use it to perform network operations such as random initialization, forward passes, mutation, and distance calculation. These operations are JIT-compilable and can be vectorized with `jax.vmap`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "ef66ba76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output=Array([-1.5388116], dtype=float32, weak_type=True)\n",
      "batch_output=Array([[ 2.5477247 ],\n",
      "       [ 2.2334106 ],\n",
      "       [ 1.8713341 ],\n",
      "       [-3.7539673 ],\n",
      "       [ 1.5344429 ],\n",
      "       [ 2.7640016 ],\n",
      "       [ 0.5649997 ],\n",
      "       [-0.32709932],\n",
      "       [ 3.5273829 ],\n",
      "       [-0.64774114]], dtype=float32, weak_type=True)\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Initialize a network\n",
    "nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))\n",
    "\n",
    "# Forward pass on a single input of shape (3,)\n",
    "single_input = jax.random.normal(jax.random.PRNGKey(1), (3,))\n",
    "# Transform the genome into the form that `forward` consumes\n",
    "transformed = genome.transform(state, nodes, conns)\n",
    "output = genome.forward(state, transformed, single_input)\n",
    "print(f\"{output=}\")\n",
    "\n",
    "# Batched forward pass: vmap over the leading axis of 10 inputs\n",
    "batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))\n",
    "batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)\n",
    "print(f\"{batch_output=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "85433b5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pop_outputs.shape=(1000, 1)\n"
     ]
    }
   ],
   "source": [
    "# Initialize a population of networks (1000 networks)\n",
    "pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(\n",
    "    state, jax.random.split(jax.random.PRNGKey(0), 1000)\n",
    ")\n",
    "\n",
    "# Population forward\n",
    "pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))\n",
    "pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(\n",
    "    state, pop_nodes, pop_conns\n",
    ")\n",
    "pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(\n",
    "    state, pop_transformed, pop_inputs\n",
    ")\n",
    "\n",
    "print(f\"{pop_outputs.shape=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "f1f9df60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the network\n",
    "network = genome.network_dict(state, nodes, conns)  # Transform the network from JAX arrays to a Python dict\n",
    "genome.visualize(network, save_path=\"./origin_network.svg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da5ad8ae",
   "metadata": {},
   "source": [
    ""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "16966b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mutate the network\n",
    "mutated_nodes, mutated_conns = genome.execute_mutation(\n",
    "    state,\n",
    "    jax.random.PRNGKey(2),\n",
    "    nodes,\n",
    "    conns,\n",
    "    new_node_key=jnp.asarray(5),  # at most 1 node can be added in each mutation\n",
    "    new_conn_keys=jnp.asarray([6, 7, 8]),  # at most 3 connections can be added\n",
    ")\n",
    "\n",
    "# Visualize the mutated network\n",
    "mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)\n",
    "genome.visualize(mutated_network, save_path=\"./mutated_network.svg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4367424",
   "metadata": {},
   "source": [
    ""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9922d34",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}