{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b568ac2",
   "metadata": {},
   "source": [
    "# Tutorial 0: Functional Programming and State"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "302f305f",
   "metadata": {},
   "source": [
    "## Functional Programming\n",
    "\n",
    "TensorNEAT uses functional programming because it is built on JAX, whose transformations (such as `jax.jit`, `jax.grad`, and `jax.vmap`) are designed to operate on pure functions.\n",
    "\n",
    "Functional programming is a programming paradigm that treats computation as the evaluation of mathematical functions and avoids changing state and mutable data. Its main features are:\n",
    "\n",
    "1. **Pure Functions**: The same input always produces the same output, with no side effects.\n",
    "2. **Immutable Data**: Once data is created, it cannot be changed. All operations return new data.\n",
    "3. **Higher-order Functions**: Functions can be passed as arguments to other functions or returned as values."
   ]
  },
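  {
   "cell_type": "markdown",
   "id": "fp-sketch-md",
   "metadata": {},
   "source": [
    "The three features above can be seen directly in JAX itself. The cell below is a minimal sketch that uses only standard JAX calls (`jnp.sum`, `.at[...].set`, `jax.grad`, `jax.vmap`). The function `square_sum` is an illustrative example and is not part of TensorNEAT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fp-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# 1. Pure function: the result depends only on the argument, and nothing outside is modified.\n",
    "def square_sum(x):\n",
    "    return jnp.sum(x**2)\n",
    "\n",
    "\n",
    "# 2. Immutable data: JAX arrays cannot be changed in place.\n",
    "#    .at[...].set(...) returns a NEW array and leaves the original untouched.\n",
    "x = jnp.array([1.0, 2.0, 3.0])\n",
    "y = x.at[0].set(10.0)\n",
    "print(f\"{x=}\")  # values 1, 2, 3 (unchanged)\n",
    "print(f\"{y=}\")  # values 10, 2, 3 (a new array)\n",
    "\n",
    "# 3. Higher-order functions: jax.grad and jax.vmap take a function and return a new function.\n",
    "grad_fn = jax.grad(square_sum)  # the gradient of sum(x**2) is 2 * x\n",
    "batched_fn = jax.vmap(square_sum)  # applies square_sum to each row of a batch\n",
    "print(f\"{grad_fn(x)=}\")  # values 2, 4, 6\n",
    "print(f\"{batched_fn(jnp.stack([x, y]))=}\")  # values 14, 113"
   ]
  },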
  {
   "cell_type": "markdown",
   "id": "ee0c8749",
   "metadata": {},
   "source": [
    "## State\n",
    "\n",
    "In TensorNEAT, we use `State` to manage the inputs and outputs of functions. `State` can be seen as a Python dictionary with additional methods.\n",
    "\n",
    "Here are some common ways to use `State`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "557a2f24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "state_a=State ({})\n",
      "state_b=State ({'a': 1, 'b': 2})\n",
      "state_b.a=1\n",
      "state_b.b=2\n"
     ]
    }
   ],
   "source": [
    "# import State\n",
    "from tensorneat.common import State\n",
    "\n",
    "# create a new state\n",
    "state_a = State()  # no arguments\n",
    "state_b = State(a=1, b=2)  # kwargs\n",
    "\n",
    "print(f\"{state_a=}\")\n",
    "print(f\"{state_b=}\")\n",
    "\n",
    "# access items with dot notation\n",
    "print(f\"{state_b.a=}\")\n",
    "print(f\"{state_b.b=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fd169ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "state_a=State ({'a': 1, 'b': 2})\n"
     ]
    }
   ],
   "source": [
    "# add new items with register (this returns a new State)\n",
    "state_a = state_a.register(a=1, b=2)\n",
    "print(f\"{state_a=}\")\n",
    "\n",
    "# We CANNOT register a key that already exists:\n",
    "# state_a = state_a.register(a=1)\n",
    "# raises ValueError(f\"Key {key} already exists in state\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8fe0d395",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "state_a=State ({'a': 3, 'b': 4})\n"
     ]
    }
   ],
   "source": [
    "# change existing items with update (this returns a new State)\n",
    "state_a = state_a.update(a=3, b=4)\n",
    "print(f\"{state_a=}\")\n",
    "\n",
    "# We CANNOT update a key that does not exist:\n",
    "# state_a = state_a.update(c=3)\n",
    "# raises ValueError(f\"Key {key} does not exist in state\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b3ed0eed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "origin_state=State ({'a': 1, 'b': 2})\n",
      "new_state=State ({'a': 3, 'b': 2})\n"
     ]
    }
   ],
   "source": [
    "# State is immutable! We always create a new state rather than modifying the existing one.\n",
    "\n",
    "origin_state = State(a=1, b=2)\n",
    "new_state = origin_state.update(a=3)\n",
    "print(f\"{origin_state=}\")  # origin_state is unchanged\n",
    "print(f\"{new_state=}\")\n",
    "\n",
    "# We cannot modify a state in place:\n",
    "# origin_state.a = 3\n",
    "# raises AttributeError(\"State is immutable\")"
   ]
  },
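  {
   "cell_type": "markdown",
   "id": "frozen-state-md",
   "metadata": {},
   "source": [
    "Why does `update` return a new object instead of changing the existing one? The cell below is a hypothetical, plain-Python sketch of a dict-backed container that reproduces the `register` / `update` / immutability behaviour shown above. `FrozenState` is an illustrative name. This sketch is an assumption about how such a class *could* be written, not TensorNEAT's actual implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "frozen-state-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch, NOT TensorNEAT's implementation of State.\n",
    "class FrozenState:\n",
    "    def __init__(self, **kwargs):\n",
    "        # bypass our own __setattr__ to store the underlying dict exactly once\n",
    "        object.__setattr__(self, \"_data\", dict(kwargs))\n",
    "\n",
    "    def __getattr__(self, name):\n",
    "        # only called when normal lookup fails, i.e. for the stored keys\n",
    "        try:\n",
    "            return self._data[name]\n",
    "        except KeyError:\n",
    "            raise AttributeError(name) from None\n",
    "\n",
    "    def __setattr__(self, name, value):\n",
    "        # forbid in-place modification\n",
    "        raise AttributeError(\"State is immutable\")\n",
    "\n",
    "    def register(self, **kwargs):\n",
    "        # only NEW keys are allowed; a new object is returned\n",
    "        for key in kwargs:\n",
    "            if key in self._data:\n",
    "                raise ValueError(f\"Key {key} already exists in state\")\n",
    "        return FrozenState(**{**self._data, **kwargs})\n",
    "\n",
    "    def update(self, **kwargs):\n",
    "        # only EXISTING keys are allowed; a new object is returned\n",
    "        for key in kwargs:\n",
    "            if key not in self._data:\n",
    "                raise ValueError(f\"Key {key} does not exist in state\")\n",
    "        return FrozenState(**{**self._data, **kwargs})\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"FrozenState ({self._data})\"\n",
    "\n",
    "\n",
    "s0 = FrozenState(a=1, b=2)\n",
    "s1 = s0.update(a=3)\n",
    "print(f\"{s0=}\")  # s0 still holds a=1\n",
    "print(f\"{s1=}\")  # s1 is a new object with a=3"
   ]
  },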
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8c73e60c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "new_state=State ({'a': Array(7, dtype=int32, weak_type=True), 'b': Array(7, dtype=int32, weak_type=True), 'c': Array(7, dtype=int32, weak_type=True)})\n"
     ]
    }
   ],
   "source": [
    "# State can be passed into and returned from jit-compiled JAX functions\n",
    "import jax\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def func(state):\n",
    "    c = state.a + state.b  # fetch items from state\n",
    "    state = state.update(a=c, b=c)  # update items in state\n",
    "    state = state.register(c=c)  # add a new item to state\n",
    "    return state  # return the new state\n",
    "\n",
    "\n",
    "new_state = func(state_a)  # state_a holds a=3, b=4, so c = 7\n",
    "print(f\"{new_state=}\")"
   ]
  },
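  {
   "cell_type": "markdown",
   "id": "pytree-state-md",
   "metadata": {},
   "source": [
    "`jax.jit` can only trace arguments that JAX knows how to flatten into arrays (so-called pytrees). The cell below is a minimal sketch of how a dict-backed container can be registered as a pytree with `jax.tree_util.register_pytree_node_class`, so that it can pass through a jitted function. `PytreeState` is a hypothetical name. We are assuming, not quoting, that TensorNEAT's `State` relies on a similar mechanism."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pytree-state-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.tree_util import register_pytree_node_class\n",
    "\n",
    "\n",
    "# Hypothetical sketch, NOT TensorNEAT's implementation of State.\n",
    "@register_pytree_node_class\n",
    "class PytreeState:\n",
    "    def __init__(self, **kwargs):\n",
    "        self._data = dict(kwargs)\n",
    "\n",
    "    def __getattr__(self, name):\n",
    "        # only called for names not found normally, i.e. the stored keys\n",
    "        try:\n",
    "            return self._data[name]\n",
    "        except KeyError:\n",
    "            raise AttributeError(name) from None\n",
    "\n",
    "    def update(self, **kwargs):\n",
    "        return PytreeState(**{**self._data, **kwargs})\n",
    "\n",
    "    def tree_flatten(self):\n",
    "        keys = tuple(sorted(self._data))  # static structure (must be hashable)\n",
    "        children = tuple(self._data[k] for k in keys)  # leaves that jit traces\n",
    "        return children, keys\n",
    "\n",
    "    @classmethod\n",
    "    def tree_unflatten(cls, keys, children):\n",
    "        return cls(**dict(zip(keys, children)))\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def step(state):\n",
    "    c = state.a + state.b\n",
    "    return state.update(a=c, b=c)\n",
    "\n",
    "\n",
    "out_state = step(PytreeState(a=jnp.int32(3), b=jnp.int32(4)))\n",
    "print(out_state.a, out_state.b)  # both are 7"
   ]
  },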
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2bd732ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loaded_state=State ({'a': 1, 'b': 2, 'c': 3, 'd': 4})\n"
     ]
    }
   ],
   "source": [
    "# Save the state to a file (using pickle) and load it back.\n",
    "state = State(a=1, b=2, c=3, d=4)\n",
    "state.save(\"tutorial_0_state.pkl\")\n",
    "loaded_state = State.load(\"tutorial_0_state.pkl\")\n",
    "print(f\"{loaded_state=}\")"
   ]
  },
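  {
   "cell_type": "markdown",
   "id": "pickle-sketch-md",
   "metadata": {},
   "source": [
    "The comment above says that `save` and `load` use pickle. As a hedged illustration of what such a round trip involves, the cell below saves a dictionary of JAX arrays and reloads it with the standard `pickle` module, using a temporary directory. `save_obj` and `load_obj` are hypothetical helpers, not TensorNEAT functions. As with any pickle file, only load files from sources you trust, because unpickling can execute arbitrary code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pickle-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def save_obj(obj, path):\n",
    "    # hypothetical helper: serialize any picklable object to a binary file\n",
    "    with open(path, \"wb\") as f:\n",
    "        pickle.dump(obj, f)\n",
    "\n",
    "\n",
    "def load_obj(path):\n",
    "    # hypothetical helper: read the object back (only load files you trust!)\n",
    "    with open(path, \"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "data = {\"a\": jnp.array([1.0, 2.0]), \"b\": 3}\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    path = os.path.join(tmp_dir, \"demo_state.pkl\")\n",
    "    save_obj(data, path)\n",
    "    restored = load_obj(path)\n",
    "\n",
    "print(f\"{restored=}\")"
   ]
  },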
  {
   "cell_type": "markdown",
   "id": "846b5374",
   "metadata": {},
   "source": [
    "## Objects in TensorNEAT\n",
    "\n",
    "In the object-oriented programming (OOP) paradigm, objects store both data and the functions (methods) that operate on that data.\n",
    "\n",
    "In the functional style used by TensorNEAT, data is stored as JAX arrays (tensors), while objects hold the functions that operate on that data.\n",
    "\n",
    "For example, creating the object `genome` below does not create a genome in the NEAT sense. Instead, it defines the functions that create and evaluate genomes, together with static settings such as `num_inputs` and `max_nodes`."
   ]
  },
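  {
   "cell_type": "markdown",
   "id": "objects-sketch-md",
   "metadata": {},
   "source": [
    "The same split can be sketched in a few lines of plain JAX. The hypothetical `TinyLinear` class below stores only static configuration and methods. All numerical data (the weights) lives in arrays that are passed in and returned explicitly, loosely mirroring the `initialize` / `forward` pattern that `genome` follows in the next cells. The class and its method bodies are illustrative assumptions, not TensorNEAT code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "objects-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical example class, NOT part of TensorNEAT.\n",
    "class TinyLinear:\n",
    "    def __init__(self, num_inputs, num_outputs):\n",
    "        # only static configuration is stored on the object\n",
    "        self.num_inputs = num_inputs\n",
    "        self.num_outputs = num_outputs\n",
    "\n",
    "    def initialize(self, randkey):\n",
    "        # returns NEW data (weights, bias); nothing is stored on self\n",
    "        w_key, b_key = jax.random.split(randkey)\n",
    "        w = jax.random.normal(w_key, (self.num_inputs, self.num_outputs))\n",
    "        b = jax.random.normal(b_key, (self.num_outputs,))\n",
    "        return w, b\n",
    "\n",
    "    def forward(self, params, inputs):\n",
    "        # a pure function of (params, inputs)\n",
    "        w, b = params\n",
    "        return inputs @ w + b\n",
    "\n",
    "\n",
    "model = TinyLinear(num_inputs=3, num_outputs=1)\n",
    "params = model.initialize(jax.random.key(0))\n",
    "tiny_outputs = jax.jit(model.forward)(params, jnp.array([1.0, 2.0, 3.0]))\n",
    "print(f\"{tiny_outputs=}\")"
   ]
  },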
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dac2bee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorneat.genome import DefaultGenome\n",
    "\n",
    "genome = DefaultGenome(\n",
    "    num_inputs=3,\n",
    "    num_outputs=1,\n",
    "    max_nodes=5,\n",
    "    max_conns=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c196cc4",
   "metadata": {},
   "source": [
    "`genome` only holds the functions that define how genomes are created and evaluated in the NEAT algorithm.\n",
    "\n",
    "To create an actual genome that can be used in computation, we need the following steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7c597c04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nodes=Array([[ 0. , 0.5097862 , 1. , 0. , 0. ],\n",
      " [ 1. , 0.9807121 , 1. , 0. , 0. ],\n",
      " [ 2. , -0.8425486 , 1. , 0. , 0. ],\n",
      " [ 3. , -0.53765106, 1. , 0. , 0. ],\n",
      " [ nan, nan, nan, nan, nan]], dtype=float32, weak_type=True)\n",
      "conns=Array([[0. , 3. , 0.785558 ],\n",
      " [1. , 3. , 2.3734226 ],\n",
      " [2. , 3. , 0.07902155],\n",
      " [ nan, nan, nan],\n",
      " [ nan, nan, nan]], dtype=float32, weak_type=True)\n"
     ]
    }
   ],
   "source": [
    "# set up the genome: it stores the information it needs in a new State\n",
    "state = genome.setup()\n",
    "\n",
    "# create a new genome; its data is just two arrays, nodes and conns\n",
    "randkey = jax.random.key(0)\n",
    "nodes, conns = genome.initialize(state, randkey)\n",
    "print(f\"{nodes=}\")\n",
    "print(f\"{conns=}\")"
   ]
  },
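  {
   "cell_type": "markdown",
   "id": "padding-sketch-md",
   "metadata": {},
   "source": [
    "In the output above, `nodes` has `max_nodes=5` rows and `conns` has `max_conns=5` rows, and the unused rows are filled with `nan`. Presumably this lets genomes of different sizes fit into fixed-shape arrays, which `jax.jit` and `jax.vmap` require. Judging from this output, the first column of `nodes` looks like a node index (0 to 2 for the three inputs, 3 for the output). The first two columns of `conns` look like the source and target node of each connection. This reading is our assumption and is not documented here. The sketch below counts the real (non-padding) rows, using values copied from the output above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "padding-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "nan = jnp.nan\n",
    "# values copied from the printed nodes and conns above\n",
    "example_nodes = jnp.array([\n",
    "    [0.0, 0.5097862, 1.0, 0.0, 0.0],\n",
    "    [1.0, 0.9807121, 1.0, 0.0, 0.0],\n",
    "    [2.0, -0.8425486, 1.0, 0.0, 0.0],\n",
    "    [3.0, -0.53765106, 1.0, 0.0, 0.0],\n",
    "    [nan, nan, nan, nan, nan],\n",
    "])\n",
    "example_conns = jnp.array([\n",
    "    [0.0, 3.0, 0.785558],\n",
    "    [1.0, 3.0, 2.3734226],\n",
    "    [2.0, 3.0, 0.07902155],\n",
    "    [nan, nan, nan],\n",
    "    [nan, nan, nan],\n",
    "])\n",
    "\n",
    "# a row is a real entry if its first column is not nan; nan rows are unused padding\n",
    "num_nodes = int(jnp.sum(~jnp.isnan(example_nodes[:, 0])))\n",
    "num_conns = int(jnp.sum(~jnp.isnan(example_conns[:, 0])))\n",
    "print(f\"{num_nodes=}\")  # 4 of the max_nodes=5 slots are used\n",
    "print(f\"{num_conns=}\")  # 3 of the max_conns=5 slots are used"
   ]
  },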
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1d8ae137",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "outputs=Array([5.231817], dtype=float32, weak_type=True)\n"
     ]
    }
   ],
   "source": [
    "# compute the network output for one input vector\n",
    "inputs = jax.numpy.array([1, 2, 3])\n",
    "\n",
    "# transform the genome data into the form that forward expects, then run it\n",
    "transformed = genome.transform(state, nodes, conns)\n",
    "outputs = genome.forward(state, transformed, inputs)\n",
    "\n",
    "print(f\"{outputs=}\")"
   ]
  },
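  {
   "cell_type": "markdown",
   "id": "vmap-sketch-md",
   "metadata": {},
   "source": [
    "Because `forward` is a pure function of `(state, transformed, inputs)`, a batch of inputs can in principle be processed with `jax.vmap` by mapping only over the last argument. With TensorNEAT this would presumably be `jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch)`, but we have not run that here. To keep the cell self-contained, the sketch below uses a hypothetical stand-in, `toy_forward`, with the same argument order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "vmap-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def toy_forward(state, transformed, inputs):\n",
    "    # hypothetical stand-in with the same argument order as genome.forward;\n",
    "    # it is a single tanh unit, NOT TensorNEAT's network computation\n",
    "    weights, bias = transformed\n",
    "    return jnp.tanh(inputs @ weights + bias) * state[\"scale\"]\n",
    "\n",
    "\n",
    "toy_state = {\"scale\": 1.0}\n",
    "toy_transformed = (jnp.array([[0.5], [-0.25], [0.1]]), jnp.array([0.0]))\n",
    "batch_inputs = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [-1.0, 1.0, 0.5]])\n",
    "\n",
    "# in_axes=(None, None, 0): share state and network, map over the rows of batch_inputs\n",
    "batched_forward = jax.vmap(toy_forward, in_axes=(None, None, 0))\n",
    "batch_outputs = batched_forward(toy_state, toy_transformed, batch_inputs)\n",
    "print(f\"{batch_outputs.shape=}\")  # one output vector per input row: (3, 1)"
   ]
  },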
  {
   "cell_type": "markdown",
   "id": "7f01ed5d",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "1. TensorNEAT uses the functional programming paradigm.\n",
    "2. TensorNEAT provides `State` to manage data.\n",
    "3. In TensorNEAT, objects are responsible for providing functions, rather than storing data."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}