diff --git a/tutorials/tutorial-02-Usage.ipynb b/tutorials/tutorial-02-Usage.ipynb index a35d497..2e1a59b 100644 --- a/tutorials/tutorial-02-Usage.ipynb +++ b/tutorials/tutorial-02-Usage.ipynb @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -134,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -144,36 +144,36 @@ "initializing\n", "initializing finished\n", "start compile\n", - "compile finished, cost time: 10.318278s\n", - "Generation: 1, Cost time: 45.01ms\n", + "compile finished, cost time: 10.663079s\n", + "Generation: 1, Cost time: 64.31ms\n", " \tfitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848\n", "\n", "\tnode counts: max: 4, min: 3, mean: 3.10\n", " \tconn counts: max: 3, min: 0, mean: 1.88\n", " \tspecies: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]\n", "\n", - "Generation: 2, Cost time: 53.77ms\n", + "Generation: 2, Cost time: 65.52ms\n", " \tfitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452\n", "\n", "\tnode counts: max: 5, min: 3, mean: 3.17\n", " \tconn counts: max: 4, min: 0, mean: 1.87\n", " \tspecies: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]\n", "\n", - "Generation: 3, Cost time: 21.86ms\n", + "Generation: 3, Cost time: 59.10ms\n", " \tfitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740\n", "\n", "\tnode counts: max: 6, min: 3, mean: 3.49\n", " \tconn counts: max: 6, min: 0, mean: 2.47\n", " \tspecies: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]\n", "\n", - "Generation: 4, Cost time: 24.30ms\n", + "Generation: 4, Cost time: 34.24ms\n", " \tfitness: valid cnt: 996, max: -0.0056, min: -158.4687, mean: -1.0448, std: 9.8865\n", "\n", "\tnode counts: max: 6, min: 3, mean: 3.76\n", " \tconn counts: max: 6, min: 0, mean: 2.66\n", " \tspecies: 20, [259, 96, 87, 19, 100, 9, 54, 84, 27, 52, 45, 35, 36, 3, 10, 17, 16, 3, 6, 42]\n", "\n", - "Generation: 5, Cost time: 27.68ms\n", + "Generation: 5, Cost time: 20.36ms\n", " \tfitness: valid cnt: 993, max: -0.0055, min: -4954.1787, mean: -7.3952, std: 157.9562\n", "\n", "\tnode counts: max: 6, min: 3, mean: 3.94\n", @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -272,7 +272,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -312,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -332,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -347,14 +347,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 100/100 [00:04<00:00, 23.15it/s]\n" + "100%|██████████| 100/100 [00:04<00:00, 22.38it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Generation 0: best fitness: 384.0\n", + "Generation 0: best fitness: 493.0\n", "Generation 1: evaluating population...\n" ] }, @@ -362,7 +362,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 100/100 [00:13<00:00, 7.37it/s]\n" + "100%|██████████| 100/100 [00:17<00:00, 5.62it/s]\n" ] }, { @@ -405,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -435,11 +435,21 @@ " obs, reward, terminated, truncated = [], [], [], []\n", " # we still need to step the envs one by one\n", " for i in range(POP_SIZE):\n", - " obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n", - " obs.append(obs_)\n", - " reward.append(reward_)\n", - " terminated.append(terminated_)\n", - " truncated.append(truncated_)\n", + " if episode_over[i]:\n", + " # is already terminated\n", + " # append zeros to keep the shape\n", + " obs.append(np.zeros(4))\n", + " reward.append(0.0)\n", + " terminated.append(True)\n", + " truncated.append(False)\n", + " continue\n", + " else:\n", + " # step the env\n", + " obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])\n", + " obs.append(obs_)\n", + " reward.append(reward_)\n", + " terminated.append(terminated_)\n", + " truncated.append(truncated_)\n", "\n", " pop_observation = np.asarray(obs)\n", " pop_reward = np.asarray(reward)\n", @@ -455,34 +465,22 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/home/wls-laptop-ubuntu/miniconda3/envs/jax_env/lib/python3.10/site-packages/gymnasium/envs/classic_control/cartpole.py:180: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned terminated = True. You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior.\u001b[0m\n", - " logger.warn(\n", - "100%|██████████| 100/100 [00:14<00:00, 6.80it/s]\n" + "100%|██████████| 100/100 [00:14<00:00, 6.87it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "slow_time=14.704506635665894, fast_time=1.5114572048187256\n" + "slow_time=14.562041997909546, fast_time=1.134758710861206\n" ] - }, - { - "data": { - "text/plain": [ - "(14.704506635665894, 1.5114572048187256)" - ] - }, - "execution_count": 98, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -499,8 +497,7 @@ "fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)\n", "fast_time = time.time() - time_tic\n", "\n", - "print(f\"{slow_time=}, {fast_time=}\")\n", - "slow_time, fast_time" + "print(f\"{slow_time=}, {fast_time=}\")" ] } ],