In [72]:
%matplotlib inline
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
In [73]:
# Use seaborn's grid-backed style for every figure drawn below.
sns.set_style(style="darkgrid")
In [74]:
# Load the per-epoch evaluation results of the supervised run and keep one record per epoch.
# NOTE: pd.read_pickle can execute arbitrary code — only load result files you produced yourself.
N_EPOCHS = 100  # number of training epochs in this run (matches "100_epochs" in the run name)
RUN_NAME = "supervised-100_epochs-100k_games-256_hidden-512_min_score"

evaluation_raw = pd.read_pickle(
    "results/{0}/epochs_evaluation_best-{0}.pickle".format(RUN_NAME)
)

# `.ix` was removed in pandas 1.0. With an integer key on an integer-labelled index `.ix`
# was label-based, so `.loc` is the equivalent here. This assumes the epochs are index
# labels 0..N_EPOCHS-1 (TODO confirm against the pickle).
# `.iloc[-1]` keeps the last evaluation record stored for each epoch.
data_supervised = (
    pd.DataFrame([evaluation_raw.loc[epoch].iloc[-1] for epoch in range(N_EPOCHS)])
    .reset_index(drop=True)  # discard the per-record labels; rows are numbered by epoch
    .rename_axis("epoch")
)

data_supervised.to_excel("results/results_supervised.xlsx", engine="xlsxwriter", float_format="%.2f")
data_supervised
Out[74]:
n_games
nn
random
n_games
n_observations
mean_episode_score
max_episode_score
mean_highest_value
max_highest_value
mean_nb_env_steps
max_nb_env_steps
mean_nb_env_steps_valid
mean_nb_env_steps_invalid
max_nb_env_steps_invalid
mean_episode_score
max_episode_score
mean_highest_value
max_highest_value
mean_nb_env_steps
max_nb_env_steps
mean_nb_env_steps_valid
mean_nb_env_steps_invalid
max_nb_env_steps_invalid
epoch
0
2222.0
701961.0
966.16
2872.0
88.80
256.0
195.46
438.0
309.0
53.03
129.0
1139.60
3008.0
109.44
256.0
194.78
514.0
391.0
43.28
128.0
1
2222.0
701961.0
1244.72
3300.0
114.56
256.0
223.10
656.0
480.0
58.53
182.0
1156.04
2900.0
111.04
256.0
206.81
459.0
317.0
47.00
142.0
2
2222.0
701961.0
1201.20
2848.0
115.04
256.0
218.65
492.0
344.0
58.30
148.0
1070.28
2620.0
106.24
256.0
192.50
411.0
301.0
44.69
125.0
3
2222.0
701961.0
1189.12
2872.0
107.52
256.0
223.41
513.0
364.0
60.99
149.0
1114.84
2476.0
112.64
256.0
183.71
389.0
305.0
38.83
92.0
4
2222.0
701961.0
1160.88
4444.0
108.80
512.0
215.80
489.0
355.0
56.63
134.0
1194.72
2636.0
119.04
256.0
205.59
439.0
328.0
46.22
164.0
5
2222.0
701961.0
1098.40
3124.0
102.72
256.0
205.66
554.0
376.0
54.63
178.0
1133.20
3052.0
113.12
256.0
190.97
460.0
361.0
41.19
113.0
6
2222.0
701961.0
1234.04
3048.0
113.12
256.0
233.69
489.0
368.0
64.51
146.0
1212.20
2828.0
117.12
256.0
215.78
526.0
391.0
49.77
140.0
7
2222.0
701961.0
1088.76
2856.0
98.40
256.0
207.93
503.0
366.0
55.05
137.0
1112.72
2640.0
112.00
256.0
191.46
419.0
325.0
40.94
99.0
8
2222.0
701961.0
1079.88
2400.0
100.00
256.0
215.10
473.0
340.0
58.77
133.0
1028.72
2788.0
98.24
256.0
187.25
444.0
316.0
41.71
128.0
9
2222.0
701961.0
1278.64
2956.0
118.88
256.0
233.37
576.0
425.0
63.88
198.0
1172.36
2956.0
113.60
256.0
199.76
474.0
344.0
43.82
146.0
10
2222.0
701961.0
1254.04
4040.0
114.08
256.0
229.07
614.0
436.0
60.65
178.0
1032.44
3116.0
98.56
256.0
187.68
562.0
401.0
42.66
161.0
11
2222.0
701961.0
1207.00
2940.0
110.08
256.0
225.66
477.0
335.0
62.19
142.0
1066.80
3016.0
104.64
256.0
190.29
429.0
327.0
44.75
116.0
12
2222.0
701961.0
1150.24
3048.0
102.08
256.0
216.29
475.0
326.0
58.85
149.0
1128.88
2976.0
110.08
256.0
198.91
516.0
377.0
44.97
139.0
13
2222.0
701961.0
1106.08
3052.0
103.68
256.0
206.28
433.0
322.0
55.67
116.0
1005.40
2324.0
100.64
256.0
176.30
343.0
259.0
39.01
117.0
14
2222.0
701961.0
1232.40
3204.0
114.24
256.0
221.18
526.0
376.0
61.05
185.0
1066.32
2404.0
108.16
256.0
183.17
398.0
281.0
40.69
122.0
15
2222.0
701961.0
1224.28
3076.0
110.72
256.0
227.73
472.0
344.0
61.20
135.0
1066.88
2952.0
104.32
256.0
191.27
606.0
422.0
42.73
184.0
16
2222.0
701961.0
1204.56
3248.0
111.52
256.0
221.83
457.0
327.0
59.22
139.0
1147.40
3132.0
112.64
256.0
193.38
451.0
313.0
43.17
146.0
17
2222.0
701961.0
1330.36
3300.0
123.52
256.0
242.20
567.0
422.0
66.55
145.0
1027.52
2596.0
100.96
256.0
185.21
463.0
334.0
41.42
129.0
18
2222.0
701961.0
1213.20
3084.0
111.68
256.0
228.41
589.0
412.0
62.69
177.0
1103.96
2960.0
105.92
256.0
189.53
394.0
284.0
41.54
110.0
19
2222.0
701961.0
1218.40
2980.0
112.80
256.0
221.69
493.0
357.0
59.54
136.0
1099.28
3128.0
110.72
256.0
187.56
462.0
375.0
40.01
113.0
20
2222.0
701961.0
1268.40
3232.0
119.68
256.0
228.06
466.0
350.0
62.35
133.0
1181.40
3132.0
117.44
256.0
201.91
502.0
388.0
43.97
125.0
21
2222.0
701961.0
1239.72
3244.0
113.76
256.0
227.71
482.0
368.0
61.30
145.0
1005.04
2416.0
96.96
256.0
184.36
418.0
294.0
41.90
126.0
22
2222.0
701961.0
1250.28
3232.0
114.72
256.0
239.01
476.0
346.0
64.87
130.0
1156.44
3080.0
112.00
256.0
200.37
538.0
401.0
45.59
137.0
23
2222.0
701961.0
1230.48
3164.0
116.32
256.0
226.15
477.0
362.0
61.10
126.0
1047.68
2796.0
104.32
256.0
182.99
390.0
303.0
39.94
101.0
24
2222.0
701961.0
1250.48
3144.0
114.88
256.0
227.21
565.0
416.0
61.94
149.0
1047.56
3000.0
102.24
256.0
189.47
572.0
411.0
42.64
161.0
25
2222.0
701961.0
1157.80
4356.0
105.60
256.0
216.69
695.0
496.0
58.30
199.0
1010.32
2400.0
99.04
256.0
184.04
347.0
282.0
40.29
116.0
26
2222.0
701961.0
1175.20
3120.0
108.96
256.0
221.37
519.0
352.0
60.83
168.0
1119.64
3180.0
109.12
256.0
194.71
478.0
373.0
42.08
118.0
27
2222.0
701961.0
1025.48
2432.0
92.80
256.0
203.58
390.0
282.0
55.92
131.0
1044.48
2952.0
102.08
256.0
185.94
447.0
321.0
42.52
165.0
28
2222.0
701961.0
1062.68
2660.0
100.48
256.0
194.81
391.0
291.0
52.72
102.0
1050.40
2652.0
101.92
256.0
190.37
419.0
316.0
43.65
120.0
29
2222.0
701961.0
1182.48
2956.0
109.12
256.0
220.75
577.0
404.0
60.05
173.0
1124.96
2412.0
110.08
256.0
199.56
463.0
316.0
44.98
147.0
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
70
2222.0
701961.0
1155.52
3096.0
109.12
256.0
212.23
512.0
361.0
57.06
151.0
1085.96
2752.0
105.60
256.0
189.10
414.0
313.0
42.92
112.0
71
2222.0
701961.0
1192.88
2496.0
112.48
256.0
224.75
453.0
335.0
61.69
142.0
1058.00
3156.0
103.36
256.0
179.59
438.0
340.0
37.79
106.0
72
2222.0
701961.0
1186.24
2952.0
114.08
256.0
223.90
456.0
312.0
61.14
146.0
1095.56
2728.0
105.92
256.0
193.14
690.0
497.0
43.15
193.0
73
2222.0
701961.0
1323.08
3568.0
117.92
256.0
248.95
714.0
507.0
69.08
207.0
1098.64
2968.0
105.28
256.0
195.78
586.0
405.0
44.14
181.0
74
2222.0
701961.0
1224.92
3092.0
114.24
256.0
231.78
495.0
366.0
61.54
129.0
1141.72
2676.0
113.60
256.0
202.14
468.0
359.0
46.57
126.0
75
2222.0
701961.0
1192.60
3232.0
109.76
256.0
227.44
498.0
353.0
62.38
145.0
1071.68
2888.0
105.92
256.0
192.63
488.0
348.0
45.74
142.0
76
2222.0
701961.0
1279.12
4456.0
118.40
512.0
234.05
675.0
454.0
63.33
221.0
1069.56
2768.0
104.80
256.0
196.96
490.0
325.0
47.05
165.0
77
2222.0
701961.0
1302.04
3176.0
116.48
256.0
244.82
525.0
390.0
67.19
160.0
1041.80
3036.0
103.04
256.0
184.44
447.0
340.0
40.79
107.0
78
2222.0
701961.0
1338.88
4292.0
126.72
256.0
244.06
642.0
453.0
66.96
189.0
1191.88
3024.0
116.16
256.0
205.91
465.0
331.0
45.37
134.0
79
2222.0
701961.0
1193.20
3168.0
105.28
256.0
225.68
525.0
372.0
61.89
153.0
1064.76
3124.0
103.52
256.0
187.40
441.0
354.0
41.05
119.0
80
2222.0
701961.0
1180.84
2972.0
110.40
256.0
221.70
547.0
360.0
61.87
187.0
1120.44
2808.0
113.28
256.0
194.76
561.0
416.0
44.26
145.0
81
2222.0
701961.0
1192.40
3300.0
106.08
256.0
226.22
550.0
376.0
63.05
174.0
1097.48
2824.0
108.96
256.0
193.32
411.0
301.0
43.52
133.0
82
2222.0
701961.0
1271.44
4688.0
115.52
512.0
233.19
587.0
432.0
63.41
157.0
1184.68
2936.0
118.72
256.0
199.91
479.0
351.0
44.15
128.0
83
2222.0
701961.0
1220.32
3588.0
113.92
256.0
225.79
529.0
378.0
61.63
151.0
1096.32
2864.0
110.40
256.0
184.34
487.0
342.0
40.26
145.0
84
2222.0
701961.0
1142.28
3256.0
105.44
256.0
208.72
523.0
371.0
55.10
152.0
1168.48
2840.0
114.56
256.0
200.48
473.0
356.0
44.81
137.0
85
2222.0
701961.0
1197.56
3396.0
107.20
256.0
225.04
577.0
407.0
61.21
170.0
1020.80
2488.0
101.44
256.0
180.26
363.0
276.0
39.98
112.0
86
2222.0
701961.0
1216.20
3120.0
115.52
256.0
221.49
498.0
343.0
60.68
155.0
1146.88
3096.0
113.60
256.0
195.67
414.0
321.0
42.96
126.0
87
2222.0
701961.0
1180.68
3248.0
110.08
256.0
225.12
477.0
327.0
61.40
150.0
1050.80
2796.0
103.36
256.0
183.82
581.0
435.0
40.91
146.0
88
2222.0
701961.0
1099.68
2788.0
102.40
256.0
209.54
449.0
325.0
56.87
131.0
1023.64
2768.0
101.12
256.0
186.23
490.0
317.0
42.64
173.0
89
2222.0
701961.0
1153.28
3480.0
104.32
256.0
221.64
558.0
392.0
60.84
181.0
1089.64
2964.0
106.56
256.0
189.89
455.0
344.0
41.20
143.0
90
2222.0
701961.0
1249.28
2904.0
120.32
256.0
229.48
464.0
340.0
61.22
132.0
1066.16
2792.0
107.68
256.0
189.94
483.0
366.0
42.07
117.0
91
2222.0
701961.0
1137.40
3200.0
103.84
256.0
216.56
514.0
340.0
58.79
174.0
1136.28
3124.0
107.84
256.0
198.97
451.0
358.0
44.01
118.0
92
2222.0
701961.0
1336.88
5068.0
121.76
512.0
240.93
684.0
482.0
65.71
212.0
1062.32
2752.0
101.44
256.0
192.62
508.0
367.0
43.65
141.0
93
2222.0
701961.0
1207.44
3040.0
108.64
256.0
226.68
440.0
317.0
62.11
123.0
1078.44
2752.0
104.32
256.0
190.12
384.0
308.0
42.36
127.0
94
2222.0
701961.0
1157.20
3196.0
105.44
256.0
216.62
501.0
363.0
58.71
138.0
1139.48
3068.0
110.08
256.0
197.63
499.0
360.0
43.71
172.0
95
2222.0
701961.0
1231.00
4028.0
110.72
256.0
230.57
567.0
427.0
62.12
181.0
1163.92
2392.0
115.68
256.0
202.55
490.0
325.0
46.72
165.0
96
2222.0
701961.0
1142.84
3176.0
104.16
256.0
216.69
510.0
359.0
58.75
181.0
1104.60
3300.0
109.44
256.0
187.22
485.0
358.0
40.55
127.0
97
2222.0
701961.0
1203.32
3992.0
111.20
256.0
226.51
538.0
384.0
61.75
154.0
1030.16
3204.0
101.44
256.0
182.46
540.0
399.0
40.18
141.0
98
2222.0
701961.0
1226.96
3144.0
114.72
256.0
224.22
483.0
344.0
60.31
140.0
1099.64
3232.0
109.28
256.0
185.31
495.0
370.0
40.11
139.0
99
2222.0
701961.0
1283.04
2708.0
119.68
256.0
234.18
537.0
376.0
63.66
161.0
1092.76
2956.0
104.64
256.0
189.79
477.0
376.0
42.40
130.0
100 rows × 20 columns
In [82]:
# Compare the trained network against random play: mean episode score after each epoch.
fig, ax = plt.subplots(figsize=(15, 10))
fig.suptitle("""NN play vs. Random play
Mean episode score is from 100 tests after each epoch.
training for 100 epochs, each 2222 games (701961 observations)""", fontsize=14, fontweight="bold")

# `Index.get_values()` was removed in pandas 1.0; matplotlib accepts the index directly.
epochs = data_supervised.index
ax.plot(epochs, data_supervised[("nn", "mean_episode_score")], "g", label="NN play")
ax.plot(epochs, data_supervised[("random", "mean_episode_score")], "r", label="Random play")
ax.set_xlabel("epoch", fontsize=14)
ax.set_ylabel("mean episode score", fontsize=14)
ax.legend(fontsize=15)
fig.savefig("images/play_supervised.png", dpi=800);
In [ ]:
Content source: gorgitko/MI-MVI_2016
Similar notebooks: