Add `flatten_conns` as the inverse of `unflatten_conns`; add "test_flatten.ipynb" as a test for both.
208 lines
5.7 KiB
Plaintext
208 lines
5.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
 "cell_type": "code",
 "execution_count": 1,
 "id": "initial_id",
 "metadata": {
  "collapsed": true,
  "ExecuteTime": {
   "end_time": "2024-05-30T11:40:55.584592400Z",
   "start_time": "2024-05-30T11:40:53.016051600Z"
  }
 },
 "outputs": [],
 "source": [
  "# third-party\n",
  "import jax\n",
  "import jax.numpy as jnp\n",
  "from jax import vmap\n",
  "\n",
  "# project-local: genome under test and the flatten/unflatten pair being checked\n",
  "from algorithm.neat.genome import DefaultGenome\n",
  "from utils.tools import flatten_conns, unflatten_conns"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 2,
 "outputs": [
  {
   "data": {
    "text/plain": "((10, 5), (10, 4))"
   },
   "execution_count": 2,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "# Small genome for the round-trip test; the capacities fix the row counts of nodes/conns.\n",
  "NUM_INPUTS, NUM_OUTPUTS = 3, 2\n",
  "MAX_NODES, MAX_CONNS = 10, 10\n",
  "\n",
  "genome = DefaultGenome(\n",
  "    num_inputs=NUM_INPUTS,\n",
  "    num_outputs=NUM_OUTPUTS,\n",
  "    max_nodes=MAX_NODES,\n",
  "    max_conns=MAX_CONNS,\n",
  ")\n",
  "state = genome.setup()\n",
  "key = jax.random.PRNGKey(0)  # fixed seed so the displayed arrays are reproducible\n",
  "nodes, conns = genome.initialize(state, key)\n",
  "assert nodes.shape[0] == MAX_NODES and conns.shape[0] == MAX_CONNS\n",
  "nodes.shape, conns.shape"
 ],
 "metadata": {
  "collapsed": false,
  "ExecuteTime": {
   "end_time": "2024-05-30T11:40:59.021858400Z",
   "start_time": "2024-05-30T11:40:55.592593400Z"
  }
 },
 "id": "89fb5cd0e77a028d"
},
|
|
{
 "cell_type": "code",
 "execution_count": 3,
 "outputs": [
  {
   "data": {
    "text/plain": "(2, 10, 10)"
   },
   "execution_count": 3,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "unflatten = unflatten_conns(nodes, conns)\n",
  "# trailing axes should match the node capacity (recorded output: (2, 10, 10) for 10 node rows)\n",
  "assert unflatten.shape[-2:] == (nodes.shape[0], nodes.shape[0])\n",
  "unflatten.shape"
 ],
 "metadata": {
  "collapsed": false,
  "ExecuteTime": {
   "end_time": "2024-05-30T11:40:59.472701700Z",
   "start_time": "2024-05-30T11:40:59.021858400Z"
  }
 },
 "id": "aaa88227bbf29936"
},
|
|
{
 "cell_type": "code",
 "execution_count": 4,
 "outputs": [
  {
   "data": {
    "text/plain": "(Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32, weak_type=True),\n Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32))"
   },
   "execution_count": 4,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "# single flatten: flatten_conns must invert unflatten_conns, NaN padding rows included\n",
  "flatten = flatten_conns(nodes, unflatten, C=conns.shape[0])\n",
  "assert flatten.shape == conns.shape\n",
  "assert jnp.allclose(flatten, conns, equal_nan=True), 'flatten_conns is not the inverse of unflatten_conns'\n",
  "conns, flatten"
 ],
 "metadata": {
  "collapsed": false,
  "ExecuteTime": {
   "end_time": "2024-05-30T11:41:00.308954100Z",
   "start_time": "2024-05-30T11:40:59.469541500Z"
  }
 },
 "id": "f2c65de38fdcff8f"
},
|
|
{
 "cell_type": "code",
 "execution_count": 8,
 "outputs": [
  {
   "data": {
    "text/plain": "((3, 10, 5), (3, 10, 4))"
   },
   "execution_count": 8,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "# batch flatten setup: initialize a small population, one PRNG key per genome\n",
  "POP_SIZE = 3\n",
  "key = jax.random.PRNGKey(0)\n",
  "keys = jax.random.split(key, POP_SIZE)\n",
  "pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(state, keys)\n",
  "assert pop_nodes.shape[0] == pop_conns.shape[0] == POP_SIZE\n",
  "pop_nodes.shape, pop_conns.shape"
 ],
 "metadata": {
  "collapsed": false,
  "ExecuteTime": {
   "end_time": "2024-05-30T11:43:09.287012800Z",
   "start_time": "2024-05-30T11:43:09.230179800Z"
  }
 },
 "id": "fe89b178b721d656"
},
|
|
{
 "cell_type": "code",
 "execution_count": 9,
 "outputs": [
  {
   "data": {
    "text/plain": "(3, 2, 10, 10)"
   },
   "execution_count": 9,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "# unflatten every genome of the population in one vectorized call\n",
  "batched_unflatten = vmap(unflatten_conns)\n",
  "pop_unflatten = batched_unflatten(pop_nodes, pop_conns)\n",
  "pop_unflatten.shape"
 ],
 "metadata": {
  "collapsed": false,
  "ExecuteTime": {
   "end_time": "2024-05-30T11:43:10.004429100Z",
   "start_time": "2024-05-30T11:43:09.404949800Z"
  }
 },
 "id": "14bbb257e5ddeab"
},
|
|
{
 "cell_type": "code",
 "execution_count": 10,
 "outputs": [
  {
   "data": {
    "text/plain": "(3, 10, 4)"
   },
   "execution_count": 10,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "# batch flatten: vmap over genomes; the connection capacity is shared (in_axes=None)\n",
  "pop_flatten = jax.vmap(flatten_conns, in_axes=(0, 0, None))(pop_nodes, pop_unflatten, pop_conns.shape[1])\n",
  "assert pop_flatten.shape == pop_conns.shape\n",
  "assert jnp.allclose(pop_flatten, pop_conns, equal_nan=True), 'batched flatten_conns is not the inverse of unflatten_conns'\n",
  "pop_flatten.shape"
 ],
 "metadata": {
  "collapsed": false,
  "ExecuteTime": {
   "end_time": "2024-05-30T11:43:39.983690700Z",
   "start_time": "2024-05-30T11:43:39.208549Z"
  }
 },
 "id": "8e5cdf6140c81dc0"
}
|
|
],
|
|
"metadata": {
 "kernelspec": {
  "display_name": "Python 3",
  "language": "python",
  "name": "python3"
 },
 "language_info": {
  "codemirror_mode": {
   "name": "ipython",
   "version": 3
  },
  "file_extension": ".py",
  "mimetype": "text/x-python",
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3"
 }
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|