fix tiny warning in tutorial
This commit is contained in:
@@ -50,7 +50,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 90,
|
"execution_count": 20,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -83,7 +83,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 91,
|
"execution_count": 21,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -134,7 +134,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 92,
|
"execution_count": 22,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -144,36 +144,36 @@
|
|||||||
"initializing\n",
|
"initializing\n",
|
||||||
"initializing finished\n",
|
"initializing finished\n",
|
||||||
"start compile\n",
|
"start compile\n",
|
||||||
"compile finished, cost time: 10.318278s\n",
|
"compile finished, cost time: 10.663079s\n",
|
||||||
"Generation: 1, Cost time: 45.01ms\n",
|
"Generation: 1, Cost time: 64.31ms\n",
|
||||||
" \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n",
|
" \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\tnode counts: max: 4, min: 3, mean: 3.10\n",
|
"\tnode counts: max: 4, min: 3, mean: 3.10\n",
|
||||||
" \tconn counts: max: 3, min: 0, mean: 1.88\n",
|
" \tconn counts: max: 3, min: 0, mean: 1.88\n",
|
||||||
" \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n",
|
" \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Generation: 2, Cost time: 53.77ms\n",
|
"Generation: 2, Cost time: 65.52ms\n",
|
||||||
" \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n",
|
" \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\tnode counts: max: 5, min: 3, mean: 3.17\n",
|
"\tnode counts: max: 5, min: 3, mean: 3.17\n",
|
||||||
" \tconn counts: max: 4, min: 0, mean: 1.87\n",
|
" \tconn counts: max: 4, min: 0, mean: 1.87\n",
|
||||||
" \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n",
|
" \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Generation: 3, Cost time: 21.86ms\n",
|
"Generation: 3, Cost time: 59.10ms\n",
|
||||||
" \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n",
|
" \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\tnode counts: max: 6, min: 3, mean: 3.49\n",
|
"\tnode counts: max: 6, min: 3, mean: 3.49\n",
|
||||||
" \tconn counts: max: 6, min: 0, mean: 2.47\n",
|
" \tconn counts: max: 6, min: 0, mean: 2.47\n",
|
||||||
" \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n",
|
" \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Generation: 4, Cost time: 24.30ms\n",
|
"Generation: 4, Cost time: 34.24ms\n",
|
||||||
" \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n",
|
" \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\tnode counts: max: 6, min: 3, mean: 3.76\n",
|
"\tnode counts: max: 6, min: 3, mean: 3.76\n",
|
||||||
" \tconn counts: max: 6, min: 0, mean: 2.66\n",
|
" \tconn counts: max: 6, min: 0, mean: 2.66\n",
|
||||||
" \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n",
|
" \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Generation: 5, Cost time: 27.68ms\n",
|
"Generation: 5, Cost time: 20.36ms\n",
|
||||||
" \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n",
|
" \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\tnode counts: max: 6, min: 3, mean: 3.94\n",
|
"\tnode counts: max: 6, min: 3, mean: 3.94\n",
|
||||||
@@ -246,7 +246,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 93,
|
"execution_count": 23,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -272,7 +272,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 94,
|
"execution_count": 24,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -312,7 +312,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 95,
|
"execution_count": 25,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -332,7 +332,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 96,
|
"execution_count": 26,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -347,14 +347,14 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"100%|██████████| 100/100 [00:04<00:00, 23.15it/s]\n"
|
"100%|██████████| 100/100 [00:04<00:00, 22.38it/s]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Generation 0: best fitness: 384.0\n",
|
"Generation 0: best fitness: 493.0\n",
|
||||||
"Generation 1: evaluating population...\n"
|
"Generation 1: evaluating population...\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -362,7 +362,7 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"100%|██████████| 100/100 [00:13<00:00, 7.37it/s]\n"
|
"100%|██████████| 100/100 [00:17<00:00, 5.62it/s]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -405,7 +405,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 97,
|
"execution_count": 27,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -435,6 +435,16 @@
|
|||||||
" obs, reward, terminated, truncated = [], [], [], []\n",
|
" obs, reward, terminated, truncated = [], [], [], []\n",
|
||||||
" # we still need to step the envs one by one\n",
|
" # we still need to step the envs one by one\n",
|
||||||
" for i in range(POP_SIZE):\n",
|
" for i in range(POP_SIZE):\n",
|
||||||
|
" if episode_over[i]:\n",
|
||||||
|
" # is already terminated\n",
|
||||||
|
" # append zeros to keep the shape\n",
|
||||||
|
" obs.append(np.zeros(4))\n",
|
||||||
|
" reward.append(0.0)\n",
|
||||||
|
" terminated.append(True)\n",
|
||||||
|
" truncated.append(False)\n",
|
||||||
|
" continue\n",
|
||||||
|
" else:\n",
|
||||||
|
" # step the env\n",
|
||||||
" obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n",
|
" obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n",
|
||||||
" obs.append(obs_)\n",
|
" obs.append(obs_)\n",
|
||||||
" reward.append(reward_)\n",
|
" reward.append(reward_)\n",
|
||||||
@@ -455,34 +465,22 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 98,
|
"execution_count": 28,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/home/wls-laptop-ubuntu/miniconda3/envs/jax_env/lib/python3.10/site-packages/gymnasium/envs/classic_control/cartpole.py:180: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n",
|
"100%|██████████| 100/100 [00:14<00:00, 6.87it/s]\n"
|
||||||
" logger.warn(\n",
|
|
||||||
"100%|██████████| 100/100 [00:14<00:00, 6.80it/s]\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"slow_time=14.704506635665894, fast_time=1.5114572048187256\n"
|
"slow_time=14.562041997909546, fast_time=1.134758710861206\n"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"(14.704506635665894, 1.5114572048187256)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 98,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -499,8 +497,7 @@
|
|||||||
"fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
|
"fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n",
|
||||||
"fast_time = time.time() - time_tic\n",
|
"fast_time = time.time() - time_tic\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(f\"{slow_time=}, {fast_time=}\")\n",
|
"print(f\"{slow_time=}, {fast_time=}\")"
|
||||||
"slow_time, fast_time"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user