add backend="jax" to sympy module
This commit is contained in:
@@ -11,13 +11,15 @@
|
||||
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
||||
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
||||
"from utils.graph import topological_sort_python\n",
|
||||
"from utils import Act, Agg"
|
||||
"from utils import Act, Agg\n",
|
||||
"\n",
|
||||
"import numpy as np"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:46.886073700Z",
|
||||
"start_time": "2024-06-12T11:35:46.042288800Z"
|
||||
"end_time": "2024-06-12T21:48:58.065855900Z",
|
||||
"start_time": "2024-06-12T21:48:57.292767Z"
|
||||
}
|
||||
},
|
||||
"id": "9531a569d9ecf774"
|
||||
@@ -29,8 +31,8 @@
|
||||
"source": [
|
||||
"genome = AdvanceInitialize(\n",
|
||||
" num_inputs=3,\n",
|
||||
" num_outputs=1,\n",
|
||||
" hidden_cnt=1,\n",
|
||||
" num_outputs=3,\n",
|
||||
" hidden_cnt=2,\n",
|
||||
" max_nodes=50,\n",
|
||||
" max_conns=500,\n",
|
||||
" node_gene=NodeGeneWithoutResponse(\n",
|
||||
@@ -38,7 +40,8 @@
|
||||
" aggregation_default=Agg.sum,\n",
|
||||
" # activation_options=(Act.tanh,),\n",
|
||||
" aggregation_options=(Agg.sum,),\n",
|
||||
" )\n",
|
||||
" ),\n",
|
||||
" output_transform=jnp.tanh,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"state = genome.setup()\n",
|
||||
@@ -51,8 +54,8 @@
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.274062400Z",
|
||||
"start_time": "2024-06-12T11:35:46.892042200Z"
|
||||
"end_time": "2024-06-12T21:49:03.858545Z",
|
||||
"start_time": "2024-06-12T21:48:58.071859800Z"
|
||||
}
|
||||
},
|
||||
"id": "4013c9f9d5472eb7"
|
||||
@@ -63,7 +66,7 @@
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
|
||||
"text/plain": "{'nodes': {0: {'idx': 0,\n 'bias': array(0.22059791, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 1: {'idx': 1,\n 'bias': array(0.7715081, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 2: {'idx': 2,\n 'bias': array(1.1184921, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 3: {'idx': 3,\n 'bias': array(0.6967973, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 4: {'idx': 4,\n 'bias': array(0.85948837, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 5: {'idx': 5,\n 'bias': array(0.19332138, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 6: {'idx': 6,\n 'bias': array(-0.31763914, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 7: {'idx': 7,\n 'bias': array(0.05656302, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'}},\n 'conns': {(0, 6): {'in': 0,\n 'out': 6,\n 'weight': array(1.6676894, dtype=float32)},\n (0, 7): {'in': 0, 'out': 7, 'weight': array(-0.05250553, dtype=float32)},\n (1, 6): {'in': 1, 'out': 6, 'weight': array(0.10137014, dtype=float32)},\n (1, 7): {'in': 1, 'out': 7, 'weight': array(-0.12093307, dtype=float32)},\n (2, 6): {'in': 2, 'out': 6, 'weight': array(-1.8677292, dtype=float32)},\n (2, 7): {'in': 2, 'out': 7, 'weight': array(-0.4195783, dtype=float32)},\n (6, 3): {'in': 6, 'out': 3, 'weight': array(1.2615877, dtype=float32)},\n (6, 4): {'in': 6, 'out': 4, 'weight': array(-0.27593768, dtype=float32)},\n (6, 5): {'in': 6, 'out': 5, 'weight': array(-0.5819819, dtype=float32)},\n (7, 3): {'in': 7, 'out': 3, 'weight': array(0.59301573, dtype=float32)},\n (7, 4): {'in': 7, 'out': 4, 'weight': array(0.19493186, dtype=float32)},\n (7, 5): {'in': 7, 'out': 5, 'weight': array(0.18183969, dtype=float32)}}}"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
@@ -71,89 +74,72 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sympy as sp\n",
|
||||
"\n",
|
||||
"symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n",
|
||||
"output_exprs"
|
||||
"network"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.325161800Z",
|
||||
"start_time": "2024-06-12T11:35:52.282008300Z"
|
||||
"end_time": "2024-06-12T21:49:03.873543600Z",
|
||||
"start_time": "2024-06-12T21:49:03.867543Z"
|
||||
}
|
||||
},
|
||||
"id": "188006cebb04847"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sympy as sp\n",
|
||||
"\n",
|
||||
"# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)\n",
|
||||
"symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n",
|
||||
"symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, np_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh, backend='numpy')\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T21:50:37.527882500Z",
|
||||
"start_time": "2024-06-12T21:50:37.518559400Z"
|
||||
}
|
||||
},
|
||||
"id": "addea793fc002900"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(sp.latex(output_exprs[0]))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.341639700Z",
|
||||
"start_time": "2024-06-12T11:35:52.323163700Z"
|
||||
}
|
||||
},
|
||||
"id": "967cb87e24373f77"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "88eee4db9eb857cd"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 12,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[-0.7940936986556304]"
|
||||
"text/plain": "(array([1.0719017 , 0.09353136, 0.22664611], dtype=float32), dtype('float32'))"
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"random_inputs = np.random.randn(3)\n",
|
||||
"res = forward_func(random_inputs)\n",
|
||||
"res "
|
||||
"random_inputs = np.random.randn(3).astype(np.float32)\n",
|
||||
"random_inputs, random_inputs.dtype"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:52.342638Z",
|
||||
"start_time": "2024-06-12T11:35:52.330160600Z"
|
||||
"end_time": "2024-06-12T21:50:38.178769100Z",
|
||||
"start_time": "2024-06-12T21:50:38.155744Z"
|
||||
}
|
||||
},
|
||||
"id": "c5581201d990ba1c"
|
||||
"id": "3aa7c874f3a5743f"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 13,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)"
|
||||
"text/plain": "Array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32, weak_type=True)"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -166,25 +152,89 @@
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:53.273851900Z",
|
||||
"start_time": "2024-06-12T11:35:52.384588600Z"
|
||||
"end_time": "2024-06-12T21:50:48.747287900Z",
|
||||
"start_time": "2024-06-12T21:50:48.560675400Z"
|
||||
}
|
||||
},
|
||||
"id": "fe3449a5bc688bc3"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"execution_count": 14,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "(array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32),\n array([ 0.9743453 , 0.57646036, -0.3080282 ], dtype=float32),\n array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32))"
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res1 = np.array(jax_forward_func(random_inputs), dtype=np.float32)\n",
|
||||
"res2 = np.array(np_forward_func(random_inputs), dtype=np.float32)\n",
|
||||
"res = np.array(genome.forward(state, transformed, random_inputs))\n",
|
||||
"res1, res2, res"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T11:35:53.274854100Z",
|
||||
"start_time": "2024-06-12T11:35:53.265856700Z"
|
||||
"end_time": "2024-06-12T21:51:15.098948600Z",
|
||||
"start_time": "2024-06-12T21:51:14.908948500Z"
|
||||
}
|
||||
},
|
||||
"id": "174c7dc3d9499f95"
|
||||
"id": "a874d434509f1092"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "(array([ True, True, True]), array([ True, False, True]))"
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res1 == res, res2 == res"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T21:51:25.857465200Z",
|
||||
"start_time": "2024-06-12T21:51:25.851465300Z"
|
||||
}
|
||||
},
|
||||
"id": "d226e5bd6e2d44d6"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "array([False, False, True])"
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"np.floor(res1 * 10000000) / 10000000 == np.floor(res2 * 10000000) / 10000000"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-12T21:00:19.851215800Z",
|
||||
"start_time": "2024-06-12T21:00:19.836443700Z"
|
||||
}
|
||||
},
|
||||
"id": "2a36ce6afc59ee8a"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -28,3 +28,7 @@ if __name__ == '__main__':
|
||||
|
||||
print(genome.repr(state, nodes, conns))
|
||||
print(network)
|
||||
|
||||
res = genome.sympy_func(state, network, precision=3)
|
||||
print(res)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user