489 lines
19 KiB
Plaintext
489 lines
19 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax, jax.numpy as jnp\n",
|
|
"\n",
|
|
"from algorithm.neat import *\n",
|
|
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
|
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
|
"from utils.graph import topological_sort_python\n",
|
|
"from tensorneat.utils import ACT, AGG\n",
|
|
"\n",
|
|
"import numpy as np"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:58:32.064076300Z",
|
|
"start_time": "2024-06-12T22:58:31.208435400Z"
|
|
}
|
|
},
|
|
"id": "9531a569d9ecf774"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"outputs": [],
|
|
"source": [
|
|
"genome = AdvanceInitialize(\n",
|
|
" num_inputs=20,\n",
|
|
" num_outputs=1,\n",
|
|
" hidden_cnt=2,\n",
|
|
" max_nodes=30,\n",
|
|
" max_conns=50,\n",
|
|
" node_gene=NodeGeneWithoutResponse(\n",
|
|
" activation_default= ACT.identity,\n",
|
|
" aggregation_default=AGG.sum,\n",
|
|
" # activation_options=(ACT.tanh, ACT.sigmoid, ACT.identity, ACT.clamped),\n",
|
|
" activation_options=( ACT.identity, ),\n",
|
|
" aggregation_options=(AGG.sum,),\n",
|
|
" ),\n",
|
|
" # output_transform=jnp.tanh,\n",
|
|
")\n",
|
|
"\n",
|
|
"state = genome.setup()\n",
|
|
"\n",
|
|
"randkey = jax.random.PRNGKey(42)\n",
|
|
"nodes, conns = genome.initialize(state, randkey)\n",
|
|
"\n",
|
|
"network = genome.network_dict(state, nodes, conns)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T23:06:06.346267500Z",
|
|
"start_time": "2024-06-12T23:06:06.121268500Z"
|
|
}
|
|
},
|
|
"id": "4013c9f9d5472eb7"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"outputs": [],
|
|
"source": [
|
|
"import sympy as sp\n",
|
|
"\n",
|
|
"# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)\n",
|
|
"# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n",
|
|
"symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T23:06:06.698824100Z",
|
|
"start_time": "2024-06-12T23:06:06.688855300Z"
|
|
}
|
|
},
|
|
"id": "addea793fc002900"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "(Array([-0.10080967, -2.373122 , -0.12224621, 1.0417817 , 0.26311624,\n -0.04573117, 0.5329444 , 1.9844177 , -0.5471916 , -3.0961084 ,\n 0.07978257, -1.0657575 , -1.6740963 , 1.2435746 , -0.5811825 ,\n 0.8970058 , -0.4379712 , 0.9084878 , -1.0984142 , 0.33063456], dtype=float32),\n dtype('float32'))"
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"random_inputs = np.random.randn(20).astype(np.float32)\n",
|
|
"random_inputs = jax.device_put(random_inputs)\n",
|
|
"random_inputs, random_inputs.dtype"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T23:06:13.317466Z",
|
|
"start_time": "2024-06-12T23:06:13.251298500Z"
|
|
}
|
|
},
|
|
"id": "3aa7c874f3a5743f"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([-3.769288], dtype=float32, weak_type=True)"
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"transformed = genome.transform(state, nodes, conns)\n",
|
|
"res = genome.forward(state, transformed, random_inputs)\n",
|
|
"res"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T23:06:13.734130200Z",
|
|
"start_time": "2024-06-12T23:06:13.530130300Z"
|
|
}
|
|
},
|
|
"id": "fe3449a5bc688bc3"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "(Array([-3.7692878], dtype=float32),\n Array([-3.769288], dtype=float32, weak_type=True))"
|
|
},
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"res1 = jnp.array(jax_forward_func(random_inputs))\n",
|
|
"res2 = genome.forward(state, transformed, random_inputs)\n",
|
|
"res1, res2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T23:06:14.069158700Z",
|
|
"start_time": "2024-06-12T23:06:13.857130700Z"
|
|
}
|
|
},
|
|
"id": "a874d434509f1092"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "False"
|
|
},
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"all(res1 == res)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T23:06:14.302366300Z",
|
|
"start_time": "2024-06-12T23:06:14.276336800Z"
|
|
}
|
|
},
|
|
"id": "d226e5bd6e2d44d6"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 73,
|
|
"outputs": [],
|
|
"source": [
|
|
"random_inputs = np.random.randn(1000).astype(np.float32)\n",
|
|
"random_inputs = jax.device_put(random_inputs)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:36:13.519243800Z",
|
|
"start_time": "2024-06-12T22:36:13.502174700Z"
|
|
}
|
|
},
|
|
"id": "90dd091f91df6fa2"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"outputs": [],
|
|
"source": [
|
|
"res1 = 1.243123123 + random_inputs * 1.12413243123123"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:36:13.898386100Z",
|
|
"start_time": "2024-06-12T22:36:13.886386800Z"
|
|
}
|
|
},
|
|
"id": "b35d2c01f50071c2"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 75,
|
|
"outputs": [],
|
|
"source": [
|
|
"res2 = random_inputs * 1.12413243123123 + 1.243123123"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:36:14.245129600Z",
|
|
"start_time": "2024-06-12T22:36:14.220718100Z"
|
|
}
|
|
},
|
|
"id": "e69752c00bd32361"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 76,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array([ True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True, True, True, True, True, True, True, True, True,\n True], dtype=bool)"
|
|
},
|
|
"execution_count": 76,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"res1 == res2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:36:14.573137100Z",
|
|
"start_time": "2024-06-12T22:36:14.551029Z"
|
|
}
|
|
},
|
|
"id": "78b9e00669100231"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 77,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "(Array(1251.1074, dtype=float32), Array(1251.1078, dtype=float32))"
|
|
},
|
|
"execution_count": 77,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"resres1 = 0.\n",
|
|
"for i in range(1000):\n",
|
|
" resres1 += res1[i]\n",
|
|
"resres2 = jnp.sum(res2)\n",
|
|
"resres1, resres2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:36:15.688800500Z",
|
|
"start_time": "2024-06-12T22:36:15.088955300Z"
|
|
}
|
|
},
|
|
"id": "91d739b6b1520363"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 78,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array(False, dtype=bool)"
|
|
},
|
|
"execution_count": 78,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"resres1 == resres2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:36:16.917702Z",
|
|
"start_time": "2024-06-12T22:36:16.906916800Z"
|
|
}
|
|
},
|
|
"id": "60502ba8f400bb60"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 66,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "(Array([-0.0913986 , 0.6177338 , 3.704111 , 1.067648 , 3.5810733 ,\n 0.3716032 , -0.10655618, 1.3503847 , 0.97305036, 0.7711922 ], dtype=float32),\n Array([-0.0913986 , 0.6177338 , 3.704111 , 1.067648 , 3.5810733 ,\n 0.3716032 , -0.10655618, 1.3503847 , 0.97305036, 0.7711922 ], dtype=float32))"
|
|
},
|
|
"execution_count": 66,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"res1, res2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:35:50.915413100Z",
|
|
"start_time": "2024-06-12T22:35:50.868608Z"
|
|
}
|
|
},
|
|
"id": "c1c092879ca33c3c"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"outputs": [],
|
|
"source": [
|
|
"real = 10\n",
|
|
"full = 50000\n",
|
|
"random_inputs = np.random.randn(real).astype(np.float32)\n",
|
|
"random_inputs = jax.device_put(random_inputs)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:55:02.540229100Z",
|
|
"start_time": "2024-06-12T22:55:02.530229200Z"
|
|
}
|
|
},
|
|
"id": "5e1e7555937f1690"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"outputs": [],
|
|
"source": [
|
|
"all_nans = jnp.full((full,), jnp.nan)\n",
|
|
"large = all_nans.at[:real].set(random_inputs)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:55:02.971772100Z",
|
|
"start_time": "2024-06-12T22:55:02.961150Z"
|
|
}
|
|
},
|
|
"id": "5ea2c2051880ca1e"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array(-5.8886395, dtype=float32)"
|
|
},
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"res1 = jnp.sum(large, where=~jnp.isnan(large))\n",
|
|
"res1"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:55:03.381951100Z",
|
|
"start_time": "2024-06-12T22:55:03.370952400Z"
|
|
}
|
|
},
|
|
"id": "9b4baa99710fad1"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array(-5.8886395, dtype=float32)"
|
|
},
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"res2 = jnp.sum(random_inputs)\n",
|
|
"res2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:55:03.781919300Z",
|
|
"start_time": "2024-06-12T22:55:03.770892100Z"
|
|
}
|
|
},
|
|
"id": "9b2897759c90b7c5"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "Array(True, dtype=bool)"
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"res1 == res2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-12T22:55:04.328871600Z",
|
|
"start_time": "2024-06-12T22:55:04.314871Z"
|
|
}
|
|
},
|
|
"id": "e312213166610144"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "7e9ca9fa15e3c401"
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|