{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "roundtrip-overview",
   "metadata": {},
   "source": [
    "# Round trip: `unflatten_conns` and `flatten_conns`\n",
    "\n",
    "Checks that `flatten_conns(nodes, unflatten_conns(nodes, conns), C)` reproduces `conns` for a freshly initialised `DefaultGenome`, first for a single genome and then for a small population mapped with `jax.vmap`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-05-30T11:40:55.584592400Z",
     "start_time": "2024-05-30T11:40:53.016051600Z"
    }
   },
   "outputs": [],
   "source": [
    "from algorithm.neat.genome import DefaultGenome\n",
    "from utils.tools import flatten_conns, unflatten_conns\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import vmap"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "single-genome-header",
   "metadata": {},
   "source": [
    "## Single genome"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "89fb5cd0e77a028d",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T11:40:59.021858400Z",
     "start_time": "2024-05-30T11:40:55.592593400Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "((10, 5), (10, 4))"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sizes and seed are named once so the flatten_conns calls below\n",
    "# cannot drift away from the genome's max_conns.\n",
    "MAX_NODES = 10\n",
    "MAX_CONNS = 10\n",
    "POP_SIZE = 3\n",
    "SEED = 0\n",
    "\n",
    "genome = DefaultGenome(\n",
    "    num_inputs=3,\n",
    "    num_outputs=2,\n",
    "    max_nodes=MAX_NODES,\n",
    "    max_conns=MAX_CONNS,\n",
    ")\n",
    "state = genome.setup()\n",
    "key = jax.random.PRNGKey(SEED)\n",
    "nodes, conns = genome.initialize(state, key)\n",
    "nodes.shape, conns.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "aaa88227bbf29936",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T11:40:59.472701700Z",
     "start_time": "2024-05-30T11:40:59.021858400Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(2, 10, 10)"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unflattened = unflatten_conns(nodes, conns)\n",
    "unflattened.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f2c65de38fdcff8f",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T11:41:00.308954100Z",
     "start_time": "2024-05-30T11:40:59.469541500Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32, weak_type=True),\n Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32))"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# single flatten\n",
    "conns_roundtrip = flatten_conns(nodes, unflattened, C=MAX_CONNS)\n",
    "\n",
    "# Unused rows are all-NaN (see the output below), so NaNs must compare equal.\n",
    "assert jnp.allclose(conns, conns_roundtrip, equal_nan=True), (\n",
    "    'single-genome round trip changed conns'\n",
    ")\n",
    "conns, conns_roundtrip"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "population-header",
   "metadata": {},
   "source": [
    "## Population of genomes (`vmap`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe89b178b721d656",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T11:43:09.287012800Z",
     "start_time": "2024-05-30T11:43:09.230179800Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "((3, 10, 5), (3, 10, 4))"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batch_flatten\n",
    "# One PRNG key per genome; state is shared (in_axes=None), keys are mapped.\n",
    "pop_keys = jax.random.split(jax.random.PRNGKey(SEED), POP_SIZE)\n",
    "pop_nodes, pop_conns = vmap(genome.initialize, in_axes=(None, 0))(state, pop_keys)\n",
    "pop_nodes.shape, pop_conns.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "14bbb257e5ddeab",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T11:43:10.004429100Z",
     "start_time": "2024-05-30T11:43:09.404949800Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(3, 2, 10, 10)"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pop_unflattened = vmap(unflatten_conns)(pop_nodes, pop_conns)\n",
    "pop_unflattened.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8e5cdf6140c81dc0",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T11:43:39.983690700Z",
     "start_time": "2024-05-30T11:43:39.208549Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(3, 10, 4)"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# in_axes=None: the size argument is shared by every genome instead of being mapped.\n",
    "pop_conns_roundtrip = vmap(flatten_conns, in_axes=(0, 0, None))(\n",
    "    pop_nodes, pop_unflattened, MAX_CONNS\n",
    ")\n",
    "\n",
    "# Same check as the single-genome case, applied to the whole population at once.\n",
    "assert jnp.allclose(pop_conns, pop_conns_roundtrip, equal_nan=True), (\n",
    "    'population round trip changed pop_conns'\n",
    ")\n",
    "pop_conns_roundtrip.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}