In [1]:
%matplotlib inline
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [2]:
# Apply seaborn's "darkgrid" look to every figure drawn later in the notebook.
sns.set_style(style="darkgrid")

In [3]:
# Load the pickled DataFrame of test results for the grid search
# (one row per target_model_update / gamma / policy combination).
data_dqn_best = pd.read_pickle(
    "results/grid_search-2000k_steps-256_hidden-batch_size_128-zero_inv_rew/"
    "test_results_best-pandas.pickle"
)

In [4]:
# Order the configurations by the DQN agent's "max_highest_value" column,
# largest first. The sorted frame is the cell's last expression, so it is
# what the notebook displays.
data_dqn_best.sort_values(
    [("dqn", "max_highest_value")],
    ascending=False,
)


Out[4]:
dqn random
mean_episode_score max_episode_score mean_highest_value max_highest_value mean_nb_env_steps max_nb_env_steps mean_nb_env_steps_valid mean_nb_env_steps_invalid max_nb_env_steps_invalid mean_episode_score max_episode_score mean_highest_value max_highest_value mean_nb_env_steps max_nb_env_steps mean_nb_env_steps_valid mean_nb_env_steps_invalid max_nb_env_steps_invalid
target_model_update gamma policy
100.000 0.80 annealed_boltzmann 0 1442.60 5796 120.32 512 455.78 1937 623 253.44 1759 1106.60 3096 110.72 256 192.71 448 346 42.82 107
0.99 annealed_boltzmann 0 1613.04 4872 135.36 512 802.36 3582 1003 555.75 2802 1134.40 2788 114.88 256 191.57 452 372 41.78 144
1000.000 0.99 annealed_boltzmann 0 1467.68 5068 129.20 512 546.14 7757 671 319.32 7530 1046.92 2796 104.32 256 184.53 424 325 39.75 157
10000.000 0.90 annealed_boltzmann 0 1555.44 5340 133.12 512 782.86 3251 543 572.79 2805 1106.84 2436 107.52 256 198.06 475 341 45.40 134
10.000 0.90 annealed_boltzmann 0 1542.88 5328 127.84 512 559.53 3684 661 345.53 3546 1151.84 3208 112.32 256 203.94 483 356 46.84 144
1000.000 0.90 annealed_boltzmann 0 1221.64 5044 109.76 512 292.38 801 488 114.63 497 1124.56 3440 109.12 256 205.09 589 424 49.67 165
100.000 0.90 boltzmann 0 1019.16 2788 96.48 256 197.76 473 320 54.54 164 1096.28 2892 107.68 256 199.91 524 381 47.95 169
eps_greedy 0 1065.64 3040 100.32 256 711.82 2048 654 465.27 1450 1211.92 3352 118.72 256 201.24 487 364 45.36 136
annealed_eps_greedy 0 1066.08 3032 98.40 256 1666.23 6048 2607 974.43 3441 1067.24 2752 107.20 256 194.92 416 295 44.44 121
annealed_boltzmann 0 1018.96 2876 98.24 256 232.16 756 463 75.50 293 1104.00 3004 108.80 256 188.31 460 342 41.08 118
1000.000 0.90 annealed_eps_greedy 0 871.76 3096 84.96 256 1544.55 7565 2475 1071.68 5090 1072.88 3144 104.00 256 192.76 468 356 43.11 129
100.000 0.99 annealed_eps_greedy 0 1124.12 2356 108.64 256 1430.61 3536 1761 884.62 2087 1121.68 2456 108.48 256 198.10 412 288 45.16 131
boltzmann 0 1094.84 2804 105.12 256 196.84 417 314 49.96 142 1202.80 3124 116.80 256 211.38 482 366 47.09 116
eps_greedy 0 1136.16 2656 108.16 256 927.11 3072 1139 626.64 1933 1098.20 3016 105.28 256 199.15 517 371 46.48 146
1000.000 0.80 annealed_boltzmann 0 1381.76 3364 115.36 256 330.50 860 505 124.82 366 1073.00 2828 102.56 256 191.84 462 332 43.03 130
annealed_eps_greedy 0 1147.12 3444 107.36 256 1256.45 4999 1413 948.71 3586 1128.32 3104 112.32 256 197.80 488 367 44.34 135
boltzmann 0 1028.00 3012 98.88 256 223.40 622 497 51.74 167 938.80 2572 92.80 256 177.10 445 329 40.99 118
100.000 0.80 eps_greedy 0 1117.16 2664 106.40 256 782.60 3021 1526 441.93 1495 991.12 3476 95.68 256 174.40 514 413 37.08 113
1000.000 0.80 eps_greedy 0 956.20 2352 92.16 256 615.64 2013 975 327.40 1165 1032.24 3340 98.56 256 185.33 505 348 40.45 157
0.90 boltzmann 0 1112.88 3160 104.80 256 229.44 708 459 67.81 249 1183.00 3256 114.08 256 198.94 488 358 43.90 130
100.000 0.80 annealed_eps_greedy 0 1097.84 2780 104.32 256 1402.65 6455 2141 1047.73 4314 1100.68 2376 110.40 256 193.19 465 354 42.21 126
1000.000 0.90 eps_greedy 0 1027.32 2692 95.04 256 918.84 3297 1148 628.04 2149 1072.24 3012 104.96 256 187.30 435 348 40.60 121
0.99 annealed_eps_greedy 0 1239.92 3500 114.88 256 1422.55 4035 1438 1025.24 2699 1070.88 3232 103.20 256 187.33 441 346 41.31 99
boltzmann 0 1079.44 2368 105.28 256 301.06 1019 681 95.15 338 1034.20 2512 98.08 256 188.70 408 310 42.69 121
eps_greedy 0 1225.84 3396 111.84 256 770.54 2305 854 505.79 1451 1093.08 2896 103.20 256 193.68 413 314 43.46 109
10000.000 0.80 annealed_boltzmann 0 853.04 2280 82.08 256 472.17 3369 1407 235.27 1962 961.56 2360 97.92 256 176.54 352 272 38.95 98
annealed_eps_greedy 0 1153.32 2748 110.24 256 1739.61 4735 1671 1270.39 3064 1144.88 2824 112.64 256 203.69 453 328 47.33 146
boltzmann 0 1023.52 2796 100.32 256 186.09 513 365 44.23 148 1164.60 2748 113.92 256 204.37 494 386 46.16 125
eps_greedy 0 1145.44 2480 105.60 256 894.37 2317 808 600.41 1509 1113.12 2732 112.00 256 193.31 410 304 43.01 134
0.90 annealed_eps_greedy 0 1037.32 2772 95.36 256 1114.11 3864 1083 816.74 2827 1052.64 2784 101.44 256 185.68 429 326 40.80 106
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
0.010 0.80 annealed_boltzmann 0 1236.64 3140 103.36 256 868.14 4511 629 656.56 3918 1093.40 3028 107.04 256 190.39 467 347 43.42 142
0.001 0.99 eps_greedy 0 1115.24 3032 103.68 256 773.34 2129 807 525.12 1329 1169.32 2912 115.52 256 201.99 478 358 45.16 120
boltzmann 0 1086.36 3080 107.84 256 198.17 504 364 49.88 174 1079.72 3036 105.60 256 186.90 422 335 41.38 130
annealed_boltzmann 0 1358.52 4028 114.56 256 1925.68 37636 835 1679.23 36994 1174.92 2576 116.80 256 201.21 452 341 44.91 144
0.010 0.90 annealed_eps_greedy 0 1107.92 2804 100.16 256 1590.49 4710 1417 1166.22 3293 1094.28 3032 105.60 256 191.83 620 414 43.32 217
0.001 0.90 eps_greedy 0 1159.08 2692 107.20 256 838.08 2140 731 566.58 1409 1134.64 3100 108.48 256 195.76 407 308 42.89 99
boltzmann 0 1015.44 2740 98.72 256 204.86 579 379 58.73 200 1023.72 2808 99.52 256 181.36 458 308 40.47 150
annealed_eps_greedy 0 1031.56 3160 93.76 256 1449.19 4507 1489 1057.72 3018 1066.96 2380 105.12 256 186.70 433 321 41.92 112
annealed_boltzmann 0 1372.92 3308 118.72 256 281.97 711 390 104.24 561 1025.84 3064 101.28 256 181.41 386 319 40.72 118
0.80 eps_greedy 0 1165.80 3140 106.56 256 1067.61 2997 1786 564.46 1338 1098.52 4560 108.00 512 186.16 416 335 39.34 124
boltzmann 0 1057.72 2524 106.24 256 190.95 386 286 46.03 142 1140.80 2764 111.04 256 196.16 470 349 42.20 136
0.010 0.90 annealed_boltzmann 0 1151.92 3388 108.00 256 233.06 708 454 68.00 254 1077.92 3160 108.16 256 195.94 529 393 45.15 136
boltzmann 0 1074.72 3280 104.00 256 265.94 1295 679 95.24 616 1082.76 2608 104.16 256 198.91 494 374 45.64 134
10.000 0.99 eps_greedy 0 1091.44 2700 102.72 256 919.61 2939 1020 619.45 1940 1133.56 4492 112.96 512 195.85 613 476 43.66 137
0.80 eps_greedy 0 1089.28 2460 102.72 256 657.13 2062 651 435.11 1411 1087.00 3000 105.60 256 197.68 659 441 46.55 218
0.99 boltzmann 0 1025.56 2628 99.68 256 188.05 382 286 41.20 96 1080.92 2488 105.28 256 193.16 423 305 43.29 122
annealed_eps_greedy 0 1263.12 3568 117.60 256 2194.41 7873 4385 1203.53 3642 1157.24 2936 112.64 256 199.68 485 361 44.57 124
annealed_boltzmann 0 1017.84 3256 88.48 256 587.90 10158 1353 328.73 9880 1093.84 2500 110.40 256 198.62 468 340 45.15 141
0.90 eps_greedy 0 877.48 2356 88.32 256 826.53 2661 1537 333.67 1124 1035.20 3276 99.68 256 187.00 543 405 42.74 138
boltzmann 0 1081.60 2824 105.60 256 212.00 585 443 49.46 148 1062.00 2912 105.12 256 193.78 523 406 44.71 121
annealed_eps_greedy 0 1104.24 2456 104.00 256 1972.58 8701 4619 1064.24 4082 1145.56 2992 110.72 256 203.91 496 378 47.13 139
0.80 boltzmann 0 1192.16 2796 114.24 256 226.48 509 354 62.03 171 1066.08 3324 103.04 256 190.18 448 347 42.67 139
0.010 0.90 eps_greedy 0 802.84 2468 77.12 256 805.97 2904 1117 501.53 1787 1123.76 2376 108.16 256 198.07 356 275 44.33 110
10.000 0.80 annealed_eps_greedy 0 1164.68 2796 109.28 256 1090.02 3281 1108 818.28 2320 1049.52 3244 103.68 256 183.87 469 366 40.82 109
annealed_boltzmann 0 829.60 2304 79.04 256 871.15 5406 2274 470.85 3132 1088.40 3084 106.40 256 192.12 472 349 42.77 123
0.010 0.99 eps_greedy 0 985.80 3244 93.44 256 650.98 2428 1503 371.78 1221 1090.52 3376 107.20 256 187.11 475 372 40.49 103
boltzmann 0 957.56 2340 92.80 256 206.11 521 342 60.83 213 1075.20 2460 104.16 256 187.47 422 316 42.37 106
annealed_eps_greedy 0 952.48 2764 90.40 256 1932.27 7424 2568 1325.85 4856 1131.60 2600 111.04 256 198.91 421 308 44.77 133
annealed_boltzmann 0 968.88 3008 90.24 256 564.38 3867 1547 313.66 2320 1014.16 2436 97.92 256 180.90 385 283 39.70 105
10000.000 0.99 eps_greedy 0 1197.92 3156 109.76 256 679.75 2317 714 446.49 1603 1131.00 3096 112.96 256 193.26 556 389 43.30 167

72 rows × 18 columns


In [43]:
# Plot the training progress of the six configurations that head the sorted
# table above (the rows whose DQN max_highest_value is 512).
# Each entry is (target_model_update, gamma, policy).
best_agents = [(100, 0.80, "annealed_boltzmann"),
               (100, 0.99, "annealed_boltzmann"),
               (1000, 0.99, "annealed_boltzmann"),
               (10000, 0.90, "annealed_boltzmann"),
               (10, 0.90, "annealed_boltzmann"),
               (1000, 0.90, "annealed_boltzmann")]
logger_file = "results/grid_search-2000k_steps-256_hidden-batch_size_128-zero_inv_rew/logger-train-dqn-2000k_steps-update_{update}-gamma_{gamma}-policy_{policy}.pickle"
colors = ["b", "g", "r", "c", "m", "y"]
episode_stride = 100  # plot every 100th scored episode, as stated in the title
output_path = "images/play_dqn.png"

fig, ax = plt.subplots(figsize=(15, 10))
fig.suptitle("DQN progress, 2 millions steps\nScore from each 100-th episode\nLegend: target model update, gamma, policy", fontsize=14, fontweight="bold")
ax.set_xlabel("i-th episode", fontsize=14)
ax.set_ylabel("episode score", fontsize=14)

for agent, color in zip(best_agents, colors):
    update, gamma, policy = agent
    agent_log = pd.read_pickle(logger_file.format(update=update, gamma=gamma, policy=policy))

    # Keep each score together with the index of the episode it belongs to.
    # Entries without "episode_score" are skipped, and pairing index and score
    # here guarantees x and y stay aligned and equally long even when some
    # entries are skipped.
    scored_episodes = [(index, episode["episode_score"])
                       for index, episode in enumerate(agent_log.episodes)
                       if "episode_score" in episode]
    if not scored_episodes:
        # Nothing to draw for this configuration; say so instead of failing
        # on an empty list and losing the curves of the other agents.
        print("No episode scores found for agent {}, {}, {}; skipping.".format(*agent))
        continue

    # Sub-sample, then make sure the very last scored episode is shown too
    # (without duplicating it when the stride already landed on it).
    sampled = scored_episodes[::episode_stride]
    if sampled[-1][0] != scored_episodes[-1][0]:
        sampled.append(scored_episodes[-1])
    episode_indices, episode_scores = zip(*sampled)

    ax.plot(episode_indices,
            episode_scores,
            color=color,
            label="{}, {}, {}".format(*agent),
            linewidth=1)

ax.legend(fontsize=15)

# savefig does not create missing directories, so create the target folder first.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, dpi=800)



In [ ]: