{ "cells": [ { "cell_type": "markdown", "id": "3b568ac2", "metadata": {}, "source": [ "# Tutorial 0: Functional Programming and State" ] }, { "cell_type": "markdown", "id": "302f305f", "metadata": {}, "source": [ "## Functional Programming\n", "\n", "TensorNEAT uses functional programming because it is built on the JAX framework, which is designed around that paradigm.\n", "\n", "Functional programming is a paradigm that treats computation as the evaluation of mathematical functions and avoids changing state and mutable data. Its main features include:\n", "\n", "1. **Pure Functions**: The same input always produces the same output, with no side effects.\n", "2. **Immutable Data**: Once data is created, it cannot be changed. All operations return new data.\n", "3. **Higher-order Functions**: Functions can be passed as arguments to other functions or returned as values." ] }, { "cell_type": "markdown", "id": "ee0c8749", "metadata": {}, "source": [ "## State\n", "\n", "In TensorNEAT, we use `State` to manage the input and output of functions. `State` can be seen as a Python dictionary with additional methods.\n", "\n", "Here are some examples of how to use `State`." 
] }, { "cell_type": "code", "execution_count": 1, "id": "557a2f24", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "state_a=State ({})\n", "state_b=State ({'a': 1, 'b': 2})\n", "state_b.a=1\n", "state_b.b=2\n" ] } ], "source": [ "# import State\n", "from tensorneat.common import State\n", "\n", "# create a new state\n", "state_a = State()  # no arguments\n", "state_b = State(a=1, b=2)  # kwargs\n", "\n", "print(f\"{state_a=}\")\n", "print(f\"{state_b=}\")\n", "\n", "# get items from the state using dot notation\n", "print(f\"{state_b.a=}\")\n", "print(f\"{state_b.b=}\")\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "2fd169ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "state_a=State ({'a': 1, 'b': 2})\n" ] } ], "source": [ "# add new items to the state with register\n", "state_a = state_a.register(a=1, b=2)\n", "print(f\"{state_a=}\")\n", "\n", "# We CANNOT register an item that already exists\n", "# state_a = state_a.register(a=1)\n", "# will raise ValueError(f\"Key {key} already exists in state\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "8fe0d395", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "state_a=State ({'a': 3, 'b': 4})\n" ] } ], "source": [ "# update the value of an existing item with update\n", "state_a = state_a.update(a=3, b=4)\n", "print(f\"{state_a=}\")\n", "\n", "# We CANNOT update an item that does not exist\n", "# state_a = state_a.update(c=3)\n", "# will raise ValueError(f\"Key {key} does not exist in state\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "b3ed0eed", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "origin_state=State ({'a': 1, 'b': 2})\n", "new_state=State ({'a': 3, 'b': 2})\n" ] } ], "source": [ "# State is immutable! 
We always create a new state, rather than modifying the existing one.\n", "\n", "origin_state = State(a=1, b=2)\n", "new_state = origin_state.update(a=3)\n", "print(f\"{origin_state=}\")  # origin_state is not changed\n", "print(f\"{new_state=}\")\n", "\n", "# We cannot modify the state directly\n", "# origin_state.a = 3\n", "# will raise AttributeError(\"State is immutable\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "8c73e60c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "new_state=State ({'a': Array(7, dtype=int32, weak_type=True), 'b': Array(7, dtype=int32, weak_type=True), 'c': Array(7, dtype=int32, weak_type=True)})\n" ] } ], "source": [ "# State can be used in JAX functions\n", "import jax\n", "\n", "\n", "@jax.jit\n", "def func(state):\n", "    c = state.a + state.b  # fetch items from state\n", "    state = state.update(a=c, b=c)  # update items in state\n", "    state = state.register(c=c)  # add new item to state\n", "    return state  # return the new state\n", "\n", "\n", "new_state = func(state_a)\n", "print(f\"{new_state=}\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "2bd732ca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loaded_state=State ({'a': 1, 'b': 2, 'c': 3, 'd': 4})\n" ] } ], "source": [ "# Save the state to a file (using pickle) and load it back.\n", "state = State(a=1, b=2, c=3, d=4)\n", "state.save(\"tutorial_0_state.pkl\")\n", "loaded_state = State.load(\"tutorial_0_state.pkl\")\n", "print(f\"{loaded_state=}\")" ] }, { "cell_type": "markdown", "id": "846b5374", "metadata": {}, "source": [ "## Objects in TensorNEAT\n", "\n", "In the object-oriented programming (OOP) paradigm, both data and functions are stored in objects. 
\n", "\n", "In the functional style used by TensorNEAT, data is stored as JAX arrays (tensors), while functions are stored in objects.\n", "\n", "For example, when we create an object `genome`, we are not creating a genome instance of the NEAT algorithm. We are actually defining a set of functions!" ] }, { "cell_type": "code", "execution_count": 7, "id": "dac2bee6", "metadata": {}, "outputs": [], "source": [ "from tensorneat.genome import DefaultGenome\n", "\n", "genome = DefaultGenome(\n", "    num_inputs=3,\n", "    num_outputs=1,\n", "    max_nodes=5,\n", "    max_conns=5,\n", ")" ] }, { "cell_type": "markdown", "id": "2c196cc4", "metadata": {}, "source": [ "`genome` only stores the functions that define how a genome operates in the NEAT algorithm.\n", "\n", "To create a genome that can take part in computation, we need the following steps." ] }, { "cell_type": "code", "execution_count": 8, "id": "7c597c04", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nodes=Array([[ 0. , 0.5097862 , 1. , 0. , 0. ],\n", " [ 1. , 0.9807121 , 1. , 0. , 0. ],\n", " [ 2. , -0.8425486 , 1. , 0. , 0. ],\n", " [ 3. , -0.53765106, 1. , 0. , 0. ],\n", " [ nan, nan, nan, nan, nan]], dtype=float32, weak_type=True)\n", "conns=Array([[0. , 3. , 0.785558 ],\n", " [1. , 3. , 2.3734226 ],\n", " [2. , 3. 
, 0.07902155],\n", " [ nan, nan, nan],\n", " [ nan, nan, nan]], dtype=float32, weak_type=True)\n" ] } ], "source": [ "# set up the genome so that it stores the information it needs in a State\n", "state = genome.setup()\n", "\n", "# create a new genome\n", "randkey = jax.random.key(0)\n", "nodes, conns = genome.initialize(state, randkey)\n", "print(f\"{nodes=}\")\n", "print(f\"{conns=}\")\n", " " ] }, { "cell_type": "code", "execution_count": 9, "id": "1d8ae137", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "outputs=Array([5.231817], dtype=float32, weak_type=True)\n" ] } ], "source": [ "# run a forward pass through the genome's network\n", "inputs = jax.numpy.array([1, 2, 3])\n", "\n", "transformed = genome.transform(state, nodes, conns)\n", "outputs = genome.forward(state, transformed, inputs)\n", "\n", "print(f\"{outputs=}\")" ] }, { "cell_type": "markdown", "id": "7f01ed5d", "metadata": {}, "source": [ "## Conclusion\n", "1. TensorNEAT uses the functional programming paradigm.\n", "2. TensorNEAT provides `State` to manage data.\n", "3. In TensorNEAT, objects are responsible for controlling functions, rather than storing data." ] } ], "metadata": { "kernelspec": { "display_name": "jax_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }