212 lines
4.8 KiB
Plaintext
212 lines
4.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax, jax.numpy as jnp\n",
|
|
"\n",
|
|
"from algorithm.neat import *\n",
|
|
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
|
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
|
"from utils.graph import topological_sort_python\n",
|
|
"from utils import Act, Agg"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T11:35:46.886073700Z",
|
|
"start_time": "2024-06-12T11:35:46.042288800Z"
|
|
}
|
|
},
|
|
"id": "9531a569d9ecf774"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"outputs": [],
|
|
"source": [
|
|
"genome = AdvanceInitialize(\n",
|
|
" num_inputs=3,\n",
|
|
" num_outputs=1,\n",
|
|
" hidden_cnt=1,\n",
|
|
" max_nodes=50,\n",
|
|
" max_conns=500,\n",
|
|
" node_gene=NodeGeneWithoutResponse(\n",
|
|
" # activation_default=Act.tanh,\n",
|
|
" aggregation_default=Agg.sum,\n",
|
|
" # activation_options=(Act.tanh,),\n",
|
|
" aggregation_options=(Agg.sum,),\n",
|
|
" )\n",
|
|
")\n",
|
|
"\n",
|
|
"state = genome.setup()\n",
|
|
"\n",
|
|
"randkey = jax.random.PRNGKey(42)\n",
|
|
"nodes, conns = genome.initialize(state, randkey)\n",
|
|
"\n",
|
|
"network = genome.network_dict(state, nodes, conns)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T11:35:52.274062400Z",
|
|
"start_time": "2024-06-12T11:35:46.892042200Z"
|
|
}
|
|
},
|
|
"id": "4013c9f9d5472eb7"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import sympy as sp\n",
|
|
"\n",
|
|
"symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n",
|
|
"output_exprs"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T11:35:52.325161800Z",
|
|
"start_time": "2024-06-12T11:35:52.282008300Z"
|
|
}
|
|
},
|
|
"id": "addea793fc002900"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(sp.latex(output_exprs[0]))"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T11:35:52.341639700Z",
|
|
"start_time": "2024-06-12T11:35:52.323163700Z"
|
|
}
|
|
},
|
|
"id": "967cb87e24373f77"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
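"source": [
"### Sanity check: SymPy vs. JAX forward pass\n",
"\n",
"Evaluate the SymPy-generated `forward_func` and the JAX `genome.forward` on the same random input; the two results should agree closely. Small discrepancies (about 6e-4 in the run saved here) are expected: `genome.sympy_func` was called with `precision=3`, so the symbolic expression uses coefficients rounded to three decimals, whereas the JAX path computes in `float32` from the genome's parameters. Note that the input is drawn with `np.random.randn` without a seed, so the exact values change on every run."
],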
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "88eee4db9eb857cd"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "[-0.7940936986556304]"
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"random_inputs = np.random.randn(3)\n",
|
|
"res = forward_func(random_inputs)\n",
|
|
"res "
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T11:35:52.342638Z",
|
|
"start_time": "2024-06-12T11:35:52.330160600Z"
|
|
}
|
|
},
|
|
"id": "c5581201d990ba1c"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)"
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"transformed = genome.transform(state, nodes, conns)\n",
|
|
"res = genome.forward(state, transformed, random_inputs)\n",
|
|
"res"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T11:35:53.273851900Z",
|
|
"start_time": "2024-06-12T11:35:52.384588600Z"
|
|
}
|
|
},
|
|
"id": "fe3449a5bc688bc3"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T11:35:53.274854100Z",
|
|
"start_time": "2024-06-12T11:35:53.265856700Z"
|
|
}
|
|
},
|
|
"id": "174c7dc3d9499f95"
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|