Files
tensorneat-mend/tensorneat/test/test_flatten.ipynb
2024-05-31 19:43:14 +08:00

234 lines
7.1 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-05-31T09:01:41.824974900Z",
"start_time": "2024-05-31T09:01:39.138674100Z"
}
},
"outputs": [],
"source": [
"from algorithm.neat.genome import DefaultGenome\n",
"from utils.tools import flatten_conns, unflatten_conns\n",
"import jax, jax.numpy as jnp\n",
"from jax import vmap"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": "((5, 5), (5, 4))"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build a small genome (3 inputs, 1 output) and sample one individual from a fixed seed.\n",
"genome = DefaultGenome(\n",
"    num_inputs=3,\n",
"    num_outputs=1,\n",
"    max_nodes=5,\n",
"    max_conns=5,\n",
")\n",
"state = genome.setup()\n",
"key = jax.random.PRNGKey(0)\n",
"nodes, conns = genome.initialize(state, key)\n",
"(nodes.shape, conns.shape)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-31T09:01:45.179170400Z",
"start_time": "2024-05-31T09:01:41.832976100Z"
}
},
"id": "89fb5cd0e77a028d"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),\n Array([[ 0. , -1.013169 , 1. , 0. , 0. ],\n [ 1. , -0.3775248 , 1. , 0. , 0. ],\n [ 2. , 0.7407059 , 1. , 0. , 0. ],\n [ 3. , -0.66817343, 1. , 0. , 0. ],\n [ 4. , 0.5336131 , 1. , 0. , 0. ]], dtype=float32, weak_type=True),\n Array([[[ nan, nan, nan, nan,\n 0.13149254],\n [ nan, nan, nan, nan,\n 0.02001922],\n [ nan, nan, nan, nan,\n -0.79229796],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.57102853,\n nan]]], dtype=float32, weak_type=True))"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert the flat (nodes, conns) arrays into the genome's transformed form and show it.\n",
"transformed = genome.transform(state, nodes, conns)\n",
"(transformed)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-31T09:01:45.729969500Z",
"start_time": "2024-05-31T09:01:45.178173400Z"
}
},
"id": "aaa88227bbf29936"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "(Array([[ 0. , -1.013169 , 1. , 0. , 0. ],\n [ 1. , -0.3775248 , 1. , 0. , 0. ],\n [ 2. , 0.7407059 , 1. , 0. , 0. ],\n [ 3. , -0.66817343, 1. , 0. , 0. ],\n [ 4. , 0.5336131 , 1. , 0. , 0. ]], dtype=float32, weak_type=True),\n Array([[ 1. , 0. , 4. , 0.13149254],\n [ 1. , 1. , 4. , 0.02001922],\n [ 1. , 2. , 4. , -0.79229796],\n [ 1. , 4. , 3. , -0.57102853],\n [ 1. , nan, nan, nan]], dtype=float32))"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Single-genome round trip: rebuild flat (nodes, conns) from the transformed form.\n",
"nodes, conns = genome.restore(state, transformed)\n",
"(nodes, conns)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-31T09:01:46.660023600Z",
"start_time": "2024-05-31T09:01:45.724970700Z"
}
},
"id": "f2c65de38fdcff8f"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "Array([[ 1. , 3. , 0. , 1. , 4. ,\n 0.13149254],\n [ 1. , 3. , 1. , 1. , 4. ,\n 0.02001922],\n [ 1. , 3. , 2. , 1. , 4. ,\n -0.79229796],\n [ 1. , 3. , 4. , 1. , 3. ,\n -0.57102853],\n [ 1. , 3. , nan, 1. , nan,\n nan]], dtype=float32)"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Insert a constant column of ones at column index 3.\n",
"# The result is bound to a new name on purpose: reassigning `conns` made this cell\n",
"# non-idempotent, so every re-run inserted one more column into the same array.\n",
"conns_with_column = jnp.insert(conns, obj=3, values=1, axis=1)\n",
"conns_with_column"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-31T09:03:35.665080500Z",
"start_time": "2024-05-31T09:03:35.013654700Z"
}
},
"id": "10bcb665c32fb728"
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "((3, 10, 5), (3, 10, 4))"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Batch case: initialize a population of 3 genomes, one PRNG key per genome.\n",
"key = jax.random.PRNGKey(0)\n",
"keys = jax.random.split(key, 3)\n",
"# `state` is shared (in_axes None); only the keys are mapped over.\n",
"batched_initialize = vmap(genome.initialize, in_axes=(None, 0))\n",
"pop_nodes, pop_conns = batched_initialize(state, keys)\n",
"(pop_nodes.shape, pop_conns.shape)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T11:43:09.287012800Z",
"start_time": "2024-05-30T11:43:09.230179800Z"
}
},
"id": "fe89b178b721d656"
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "(3, 2, 10, 10)"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Unflatten every genome's connection list in one vmapped call.\n",
"batched_unflatten = vmap(unflatten_conns)\n",
"pop_unflatten = batched_unflatten(pop_nodes, pop_conns)\n",
"pop_unflatten.shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T11:43:10.004429100Z",
"start_time": "2024-05-30T11:43:09.404949800Z"
}
},
"id": "14bbb257e5ddeab"
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "(3, 10, 4)"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Flatten the unflattened population back into per-genome connection lists.\n",
"# The third argument used to be the literal 10, which only matches one particular\n",
"# genome configuration. Take it from pop_conns so the round trip follows whatever\n",
"# the genome was built with.\n",
"# NOTE(review): assumes this argument is the number of connection rows to emit -- confirm against flatten_conns.\n",
"num_conn_rows = pop_conns.shape[1]\n",
"flatten = jax.vmap(flatten_conns, in_axes=(0, 0, None))(pop_nodes, pop_unflatten, num_conn_rows)\n",
"# The unflatten -> flatten round trip should give back arrays of the original shape.\n",
"assert flatten.shape == pop_conns.shape, 'flatten/unflatten round trip changed the shape'\n",
"flatten.shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T11:43:39.983690700Z",
"start_time": "2024-05-30T11:43:39.208549Z"
}
},
"id": "8e5cdf6140c81dc0"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}