In [4]:
from os import environ

# Experiment configuration, passed through environment variables (values must be
# strings). NOTE(review): the parsing of these values is not visible here —
# presumably utils.ipynb (loaded below via %run) reads them; confirm.
environ['optimizer'] = 'Adam'            # 'Adam' or 'SGD' (the only values the learner cell handles)
environ['num_workers'] = '2'
environ['batch_size'] = str(2048)
# NOTE(review): the training cell calls fit_one_cycle with a literal 600 epochs,
# so this value does not reach that call.
environ['n_epochs'] = '1000'
environ['batch_norm'] = 'True'
environ['loss_func'] = 'MAPE'            # one of 'MSE', 'MAPE', 'SMAPE'
environ['layers'] = '600 350 200 180'    # space-separated sizes; presumably become `layers_sizes`
# One dropout rate per entry of `layers`. Built with join so the count follows
# `layers` and the string has no trailing separator (which '0.1 ' * 4 produced).
environ['dropouts'] = ' '.join(['0.1'] * len(environ['layers'].split()))
environ['log'] = 'False'
# NOTE(review): not passed to the Learner in the visible cells — confirm it is consumed.
environ['weight_decay'] = '0.01'
environ['cuda_device'] = 'cuda:4'
environ['dataset'] = 'data/speedup_dataset2.pkl'

# Load shared helpers into this kernel.
# NOTE(review): the following names are used below but not defined in this notebook:
# train_dev_split, fai, device, Model_BN, Model, nn, optim, mape_criterion,
# smape_criterion, rmse_criterion, get_results_df, joint_plot, dataset, batch_size,
# num_workers, log, batch_norm, loss_func, optimizer, layers_sizes, drops.
# They presumably come from utils.ipynb (parsing the environ settings above) — confirm.
%run utils.ipynb

In [5]:
# Split the dataset into train / validation / test loaders and wrap them for fastai.
train_dl, val_dl, test_dl = train_dev_split(dataset, batch_size, num_workers, log=log)

# Pass the test loader by keyword. In fastai v1 releases that added `fix_dl`, it
# became DataBunch's third positional parameter, so a positional test_dl would be
# silently used as the fixed training loader and no test set would be registered.
# The `test_dl=` keyword is accepted by both the older and newer signatures.
db = fai.basic_data.DataBunch(train_dl, val_dl, test_dl=test_dl, device=device)

In [6]:
# Infer the network's input/output widths from the training tensors
# (assumes X and Y are 2-D: (samples, features) — TODO confirm).
input_size = train_dl.dataset.X.shape[1]
output_size = train_dl.dataset.Y.shape[1]

# NOTE(review): assumes `batch_norm` was parsed to a bool; the string 'False'
# would be truthy here — confirm in utils.
if batch_norm:
    model = Model_BN(input_size, output_size, hidden_sizes=layers_sizes, drops=drops)
else:
    # NOTE(review): this branch does not use layers_sizes / drops.
    model = Model(input_size, output_size)

# Fail loudly on an unknown loss name: otherwise `criterion` is undefined on a
# fresh kernel (NameError) or silently reuses a value from an earlier run.
if loss_func == 'MSE':
    criterion = nn.MSELoss()
elif loss_func == 'MAPE':
    criterion = mape_criterion
elif loss_func == 'SMAPE':
    criterion = smape_criterion
else:
    raise ValueError(f"Unsupported loss_func {loss_func!r}; expected 'MSE', 'MAPE' or 'SMAPE'")

# Only 'SGD' overrides the Learner's default opt_func. Reject any other name
# except 'Adam' so the optimizer embedded in checkpoint names stays truthful.
if optimizer not in ('Adam', 'SGD'):
    raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'Adam' or 'SGD'")

l = fai.Learner(db, model, loss_func=criterion, metrics=[mape_criterion, rmse_criterion])

if optimizer == 'SGD':
    l.opt_func = optim.SGD

In [4]:
# Restore saved weights; the name encodes the run configuration and matches the
# r_speedup_*.pth files shown by `!ls models` below.
l = l.load(
    f"r_speedup_{optimizer}_batch_norm_{batch_norm}_{loss_func}"
    f"_nlayers_{len(layers_sizes)}_log_{log}"
)

In [125]:
# Learning-rate range test; view the resulting curve with l.recorder.plot().
l.lr_find();


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

In [126]:
# Plot loss against learning rate as recorded by lr_find.
l.recorder.plot();



In [7]:
# Peak learning rate for the one-cycle schedule (chosen from the lr_find plot).
lr = 0.001

In [8]:
# Train with the one-cycle policy for 600 epochs, peaking at `lr`.
# NOTE(review): the config cell sets n_epochs='1000', but this call uses a literal 600.
l.fit_one_cycle(cyc_len=600, max_lr=lr)


Total time: 19:09

epoch train_loss valid_loss mape_criterion rmse_criterion
1 102.462837 89.915024 89.915024 2.250221
2 95.942848 89.159119 89.159119 2.241213
3 93.677299 88.585869 88.585869 2.241689
4 92.049332 88.617554 88.617554 2.237574
5 90.969185 88.596176 88.596176 2.234123
6 91.528877 88.353172 88.353172 2.234489
7 89.688362 87.845184 87.845184 2.229401
8 88.902145 86.188789 86.188789 2.225424
9 87.189957 84.940178 84.940178 2.218786
10 86.617989 84.251053 84.251053 2.217868
11 86.144867 83.626022 83.626022 2.210706
12 84.905479 82.636887 82.636887 2.206949
13 83.870193 81.844536 81.844536 2.205022
14 83.264206 81.042114 81.042114 2.195430
15 82.299522 80.251030 80.251030 2.184253
16 81.930260 79.457634 79.457634 2.179665
17 80.887802 78.319626 78.319626 2.164439
18 80.192841 77.398598 77.398598 2.154759
19 79.316902 76.255775 76.255775 2.138506
20 78.167198 74.610641 74.610641 2.110413
21 77.248169 73.121490 73.121490 2.083524
22 76.619637 71.899475 71.899475 2.071931
23 75.689087 70.554329 70.554329 2.049546
24 75.462486 69.788422 69.788422 2.034508
25 74.128059 69.134758 69.134758 2.027010
26 74.013397 69.124947 69.124947 2.017795
27 73.529907 68.545235 68.545235 2.015644
28 72.503899 67.715485 67.715485 1.993982
29 72.018829 67.323875 67.323875 1.992565
30 71.861832 68.957306 68.957306 2.016355
31 71.359268 66.459358 66.459358 1.967845
32 70.976089 66.587563 66.587563 1.964452
33 70.218262 65.609566 65.609566 1.942873
34 70.110603 66.536560 66.536560 1.964318
35 69.803085 66.303574 66.303574 1.955917
36 69.532188 66.222702 66.222702 1.968754
37 69.123528 65.792366 65.792366 1.954502
38 68.652756 64.899185 64.899185 1.923390
39 68.343796 64.488197 64.488197 1.917168
40 67.861549 64.497505 64.497505 1.918128
41 67.511017 64.443481 64.443481 1.920465
42 67.215637 63.580688 63.580688 1.912533
43 66.639343 64.007141 64.007141 1.894527
44 65.877548 63.246277 63.246277 1.880019
45 65.822037 63.621712 63.621712 1.931469
46 65.259453 61.078979 61.078979 1.871760
47 64.938972 61.761864 61.761864 1.915089
48 64.027245 59.030312 59.030312 1.820233
49 63.683167 59.898449 59.898449 1.863873
50 63.381924 60.149769 60.149769 1.868998
51 61.532448 57.491699 57.491699 1.811738
52 61.413368 58.958481 58.958481 1.869629
53 60.256184 57.546345 57.546345 1.782692
54 61.000992 59.549824 59.549824 1.852500
55 59.750122 56.763351 56.763351 1.815845
56 58.880177 57.024956 57.024956 1.747679
57 57.977936 57.685349 57.685349 1.818881
58 57.351383 55.792374 55.792374 1.740868
59 61.805122 62.690342 62.690342 1.932269
60 62.743099 60.701401 60.701401 1.881479
61 62.275963 60.291924 60.291924 1.869433
62 60.954506 57.321407 57.321407 1.774410
63 57.536522 55.480381 55.480381 1.769087
64 59.621834 60.492695 60.492695 1.861783
65 58.414085 56.202839 56.202839 1.778402
66 58.000450 55.948418 55.948418 1.778747
67 57.790462 61.868870 61.868870 1.865341
68 57.708767 58.825874 58.825874 1.807891
69 54.291729 55.824299 55.824299 1.769154
70 53.209080 58.772564 58.772564 1.764153
71 52.375916 54.450249 54.450249 1.699080
72 51.397118 54.344074 54.344074 1.719169
73 49.561699 53.976311 53.976311 1.633079
74 49.862606 56.884064 56.884064 1.720541
75 49.479145 62.282661 62.282661 1.823001
76 48.908146 56.493301 56.493301 1.732743
77 48.634907 56.730099 56.730099 1.768547
78 48.033245 57.653049 57.653049 1.701744
79 46.540295 50.654930 50.654930 1.589988
80 47.030609 55.185963 55.185963 1.620118
81 46.125957 53.473927 53.473927 1.603792
82 44.691105 48.369156 48.369156 1.531184
83 44.350109 55.723076 55.723076 1.622825
84 45.898518 62.406368 62.406368 1.709704
85 45.058266 48.468037 48.468037 1.504565
86 43.726620 48.691936 48.691936 1.493813
87 44.156380 49.913189 49.913189 1.574697
88 42.925648 49.838276 49.838276 1.538066
89 42.416237 50.906349 50.906349 1.544893
90 41.707027 52.795776 52.795776 1.637186
91 41.388653 55.610207 55.610207 1.586232
92 42.173203 48.229458 48.229458 1.549769
93 40.100197 48.705097 48.705097 1.440520
94 41.953983 54.809658 54.809658 1.632107
95 41.144203 56.031155 56.031155 1.438029
96 42.218178 48.634567 48.634567 1.523352
97 41.702614 53.788464 53.788464 1.548024
98 40.408802 49.433216 49.433216 1.472887
99 39.715084 48.257210 48.257210 1.435997
100 44.082287 52.253365 52.253365 1.596702
101 42.178127 51.293236 51.293236 1.543357
102 40.538998 45.042782 45.042782 1.406151
103 38.412235 42.888336 42.888336 1.261014
104 37.880253 45.432343 45.432343 1.371741
105 39.726032 49.939201 49.939201 1.447204
106 41.930012 48.238358 48.238358 1.397959
107 39.206535 43.298393 43.298393 1.411555
108 38.627396 53.890980 53.890980 1.542061
109 38.115311 49.441555 49.441555 1.538206
110 37.774334 41.115311 41.115311 1.280305
111 37.242725 50.177139 50.177139 1.455027
112 37.662914 49.813534 49.813534 1.535243
113 37.925167 43.111256 43.111256 1.296932
114 37.597553 47.469551 47.469551 1.489701
115 36.807793 45.197067 45.197067 1.384499
116 37.096867 47.770855 47.770855 1.537368
117 35.754456 47.517849 47.517849 1.427904
118 37.579514 49.060925 49.060925 1.559251
119 36.904114 49.449589 49.449589 1.460099
120 36.047928 41.096149 41.096149 1.327718
121 35.932812 48.796745 48.796745 1.249137
122 36.227360 37.428486 37.428486 1.140873
123 36.399853 51.606255 51.606255 1.479480
124 36.774441 50.831142 50.831142 1.331906
125 35.642670 41.459324 41.459324 1.196531
126 36.321976 39.811092 39.811092 1.238446
127 34.772648 31.381090 31.381090 0.985488
128 34.303116 40.470329 40.470329 1.207205
129 34.631145 40.581608 40.581608 1.217004
130 34.646908 41.931065 41.931065 1.269815
131 34.482819 43.965099 43.965099 1.370830
132 33.739220 34.368599 34.368599 1.106400
133 34.495720 40.929226 40.929226 1.324431
134 36.684319 50.006027 50.006027 1.643450
135 34.724823 42.033699 42.033699 1.208778
136 34.700245 38.832420 38.832420 1.186832
137 34.516644 42.992619 42.992619 1.293013
138 33.454983 33.238853 33.238853 1.083079
139 33.991093 35.447414 35.447414 1.103924
140 33.476212 39.205551 39.205551 1.126302
141 33.460987 40.208801 40.208801 1.260162
142 34.475582 39.001701 39.001701 1.181114
143 33.446178 36.023533 36.023533 1.117951
144 34.918251 49.728058 49.728058 1.525955
145 33.617649 37.027851 37.027851 1.177658
146 32.607128 40.295261 40.295261 1.218831
147 34.180916 45.603886 45.603886 1.405544
148 34.175728 43.179070 43.179070 1.211370
149 34.803402 38.488731 38.488731 1.199201
150 33.536026 36.844143 36.844143 1.105938
151 32.863754 41.294613 41.294613 1.243217
152 33.102325 38.399017 38.399017 1.187201
153 33.886322 39.514030 39.514030 1.214723
154 32.800175 34.625404 34.625404 1.080302
155 31.884401 38.613480 38.613480 1.294129
156 33.187286 37.232204 37.232204 1.249735
157 32.381664 37.612865 37.612865 1.220467
158 32.255428 34.255901 34.255901 1.100269
159 31.848295 37.088108 37.088108 1.198890
160 30.968899 31.360840 31.360840 1.011190
161 31.737398 37.021214 37.021214 1.119648
162 35.513985 51.170986 51.170986 1.619334
163 32.419483 32.118145 32.118145 1.042945
164 32.267788 36.559731 36.559731 1.122198
165 32.265480 38.774670 38.774670 1.057342
166 31.750093 31.954960 31.954960 0.993154
167 32.550919 39.066200 39.066200 1.232327
168 31.056742 33.325405 33.325405 1.041935
169 31.377750 33.399265 33.399265 0.989797
170 33.871483 42.348354 42.348354 1.332318
171 31.340851 29.980474 29.980474 0.965288
172 31.399803 35.302532 35.302532 1.103125
173 33.367283 33.063938 33.063938 0.965582
174 31.520449 31.007048 31.007048 0.968705
175 33.009102 41.270267 41.270267 1.149267
176 32.392178 28.921974 28.921974 0.916448
177 31.051737 31.076653 31.076653 0.947544
178 30.386816 28.703388 28.703388 0.921619
179 30.807390 28.721554 28.721554 0.929247
180 30.424442 30.672453 30.672453 0.941073
181 30.349003 32.410038 32.410038 1.036487
182 30.194866 28.087044 28.087044 0.950450
183 30.263041 27.578218 27.578218 0.897694
184 29.797739 29.888056 29.888056 0.948301
185 29.935583 30.774916 30.774916 0.985125
186 29.459436 29.052956 29.052956 0.981487
187 33.849144 35.500896 35.500896 1.212891
188 32.666126 36.224243 36.224243 1.178105
189 30.029305 28.543825 28.543825 0.905285
190 32.766960 38.235004 38.235004 1.261275
191 31.579334 31.061453 31.061453 0.995024
192 31.259674 33.387699 33.387699 1.076000
193 31.128372 30.868631 30.868631 0.964490
194 30.971384 32.275379 32.275379 0.961773
195 30.854416 31.670696 31.670696 0.990046
196 30.184906 30.459307 30.459307 0.984591
197 30.470558 33.204475 33.204475 0.970237
198 30.040091 33.328411 33.328411 0.996078
199 29.553257 31.312305 31.312305 0.948989
200 29.600702 29.081434 29.081434 0.956585
201 29.032640 30.800125 30.800125 0.927248
202 29.429724 31.947538 31.947538 0.993270
203 29.506773 28.804781 28.804781 0.914111
204 29.326370 30.516644 30.516644 0.910598
205 29.402994 36.317833 36.317833 0.990869
206 29.002129 30.929520 30.929520 0.870361
207 30.279745 33.261204 33.261204 0.920278
208 29.205072 27.733316 27.733316 0.907874
209 28.521025 35.237633 35.237633 1.096733
210 28.497412 28.038363 28.038363 0.943656
211 31.133661 32.748589 32.748589 0.920423
212 30.033360 27.724091 27.724091 0.908205
213 28.693943 28.181000 28.181000 0.942253
214 28.000706 30.905306 30.905306 1.051152
215 29.153500 30.739197 30.739197 0.928864
216 28.129633 29.070465 29.070465 0.872321
217 28.315338 27.875219 27.875219 0.889311
218 27.600582 30.403532 30.403532 0.907008
219 27.180147 28.719963 28.719963 0.912207
220 27.602861 30.189648 30.189648 0.916323
221 27.851830 30.176502 30.176502 0.967201
222 28.132318 34.883808 34.883808 0.918503
223 27.547482 28.726351 28.726351 0.875762
224 27.841446 31.372299 31.372299 0.975907
225 27.053761 28.012924 28.012924 0.883935
226 27.428524 28.919622 28.919622 0.886396
227 27.623837 29.478462 29.478462 0.924545
228 26.549494 28.356609 28.356609 0.888520
229 26.166252 29.063196 29.063196 0.838485
230 26.995083 32.883877 32.883877 0.881544
231 26.979536 26.896238 26.896238 0.854347
232 26.347513 33.861176 33.861176 0.857156
233 26.862850 30.716005 30.716005 0.930571
234 26.682184 28.093563 28.093563 0.901258
235 26.211510 27.587170 27.587170 0.870875
236 26.153074 27.288422 27.288422 0.881822
237 26.078346 27.790150 27.790150 0.921486
238 26.404404 29.168247 29.168247 0.919915
239 26.428675 26.849258 26.849258 0.873241
240 26.542526 34.971981 34.971981 0.889436
241 25.921844 28.955263 28.955263 0.849940
242 25.856985 27.943659 27.943659 0.938976
243 25.410679 32.169914 32.169914 0.888684
244 25.466766 28.043381 28.043381 0.904460
245 25.664127 26.928066 26.928066 0.865609
246 25.406288 27.216732 27.216732 0.885839
247 25.136805 30.381466 30.381466 0.961977
248 25.960550 27.826349 27.826349 0.922123
249 25.955774 32.337696 32.337696 0.920436
250 25.592505 26.423412 26.423412 0.889608
251 24.991760 27.597250 27.597250 0.861189
252 25.949627 34.347752 34.347752 0.913615
253 26.009987 28.814590 28.814590 0.915201
254 25.342943 31.485123 31.485123 0.875107
255 24.925245 31.605396 31.605396 0.907865
256 25.129282 35.692726 35.692726 0.936480
257 25.764706 36.718327 36.718327 0.861335
258 25.515604 28.572878 28.572878 0.872104
259 24.672745 29.852816 29.852816 0.884588
260 24.553434 27.237350 27.237350 0.853266
261 24.927618 26.539103 26.539103 0.848360
262 24.688866 27.940100 27.940100 0.869520
263 24.790091 34.699253 34.699253 0.846236
264 24.550695 29.528303 29.528303 0.847367
265 24.272127 26.570253 26.570253 0.889928
266 24.099630 31.263144 31.263144 0.953881
267 24.633497 27.708530 27.708530 0.903770
268 24.316227 33.024181 33.024181 0.872730
269 24.144169 31.144737 31.144737 0.916001
270 24.194763 32.720905 32.720905 0.896790
271 24.339458 27.281113 27.281113 0.836668
272 24.091272 29.804873 29.804873 0.866757
273 24.092991 35.216183 35.216183 0.930799
274 23.619629 33.071480 33.071480 0.855799
275 23.990950 33.026218 33.026218 0.902117
276 24.407785 36.179108 36.179108 0.863402
277 23.897045 29.479143 29.479143 0.845802
278 23.637547 27.101412 27.101412 0.877575
279 23.636490 31.237318 31.237318 0.885830
280 23.524809 28.226509 28.226509 0.888378
281 23.535900 29.908449 29.908449 0.881832
282 23.579260 31.401525 31.401525 0.923868
283 23.483599 28.870726 28.870726 0.868786
284 23.333834 29.210741 29.210741 0.859095
285 23.539259 31.370253 31.370253 0.841539
286 23.517559 29.477880 29.477880 0.854113
287 23.327291 27.406416 27.406416 0.864116
288 23.278925 28.188749 28.188749 0.864451
289 23.204361 26.725651 26.725651 0.879889
290 23.237856 28.023657 28.023657 0.877536
291 22.992851 30.073887 30.073887 0.846566
292 22.998302 27.207613 27.207613 0.845029
293 22.701468 30.449644 30.449644 0.867035
294 23.232189 31.119488 31.119488 0.847104
295 22.928455 27.828276 27.828276 0.826710
296 22.889984 26.132402 26.132402 0.837898
297 22.581001 28.343269 28.343269 0.870220
298 22.613939 30.899712 30.899712 0.882964
299 22.557434 32.113148 32.113148 0.849321
300 22.505795 30.226730 30.226730 0.855375
301 22.548832 35.589844 35.589844 0.873132
302 22.543770 28.692303 28.692303 0.842034
303 22.701599 27.767765 27.767765 0.904211
304 22.400715 27.173031 27.173031 0.860667
305 22.197437 27.604269 27.604269 0.851240
306 22.338448 26.485975 26.485975 0.843565
307 22.221752 31.686926 31.686926 0.855039
308 22.269285 29.615650 29.615650 0.872594
309 22.107435 28.479919 28.479919 0.842593
310 22.405035 32.741665 32.741665 0.900984
311 22.369503 29.492785 29.492785 0.888509
312 22.380257 28.602345 28.602345 0.861227
313 22.052174 31.440632 31.440632 0.846058
314 21.847219 33.097412 33.097412 0.862574
315 22.089762 30.850168 30.850168 0.870820
316 22.143324 27.808676 27.808676 0.886002
317 21.871441 28.382538 28.382538 0.872122
318 21.869408 31.890919 31.890919 0.844921
319 22.021996 33.792290 33.792290 0.810507
320 21.682329 32.085899 32.085899 0.852052
321 21.485136 30.800472 30.800472 0.852270
322 21.554802 27.656782 27.656782 0.838214
323 21.582949 30.987406 30.987406 0.874131
324 21.684252 27.577974 27.577974 0.854570
325 21.649443 28.866381 28.866381 0.874095
326 21.800903 28.920055 28.920055 0.830214
327 21.779699 27.880869 27.880869 0.833062
328 21.485964 33.646626 33.646626 0.842734
329 21.803717 32.444542 32.444542 0.837914
330 21.330885 25.325987 25.325987 0.836160
331 21.149656 27.959869 27.959869 0.837313
332 21.619190 32.521076 32.521076 0.829198
333 21.277458 26.951662 26.951662 0.868845
334 21.019983 28.715425 28.715425 0.822239
335 20.938623 27.484371 27.484371 0.841657
336 21.112318 31.391222 31.391222 0.874959
337 21.314363 31.533684 31.533684 0.846035
338 21.258247 34.287395 34.287395 0.857185
339 21.088907 31.995419 31.995419 0.889313
340 21.060400 25.391491 25.391491 0.850021
341 21.191385 29.911272 29.911272 0.846361
342 21.239538 26.181438 26.181438 0.832716
343 20.908230 27.911318 27.911318 0.866273
344 20.663805 26.280628 26.280628 0.821167
345 20.665287 28.996469 28.996469 0.838689
346 20.797508 28.352690 28.352690 0.863336
347 20.920387 27.500000 27.500000 0.862113
348 20.676994 28.133013 28.133013 0.830970
349 20.972519 29.595919 29.595919 0.836804
350 20.676264 32.270893 32.270893 0.806910
351 20.843134 27.261044 27.261044 0.838302
352 20.774733 33.518471 33.518471 0.863284
353 20.609516 27.182146 27.182146 0.840244
354 20.543415 27.878954 27.878954 0.849199
355 20.385962 28.484194 28.484194 0.863096
356 20.492867 25.872372 25.872372 0.807848
357 20.535784 24.678133 24.678133 0.823150
358 20.517176 29.534426 29.534426 0.849546
359 20.470367 26.525784 26.525784 0.847718
360 20.341015 28.980532 28.980532 0.875093
361 20.425688 29.628828 29.628828 0.829189
362 20.493580 28.659624 28.659624 0.824404
363 20.421024 34.806728 34.806728 0.849442
364 20.164429 25.768372 25.768372 0.828425
365 20.437984 29.775827 29.775827 0.848319
366 20.249952 35.833237 35.833237 0.866141
367 20.339994 28.918982 28.918982 0.830377
368 20.284470 29.664816 29.664816 0.851417
369 20.330862 28.552240 28.552240 0.826999
370 20.048162 29.763744 29.763744 0.847629
371 20.228447 35.931984 35.931984 0.872037
372 19.881866 34.702980 34.702980 0.855211
373 20.269339 34.126816 34.126816 0.821539
374 19.969322 37.500626 37.500626 0.855052
375 19.955568 35.221851 35.221851 0.852405
376 20.342201 40.740067 40.740067 0.855306
377 19.976448 27.932146 27.932146 0.847538
378 20.265495 29.188465 29.188465 0.814933
379 19.943855 30.165432 30.165432 0.831588
380 19.907921 28.370747 28.370747 0.811003
381 19.806030 28.747881 28.747881 0.827588
382 19.566782 29.042213 29.042213 0.815333
383 19.962469 27.231894 27.231894 0.812813
384 19.918219 31.876888 31.876888 0.865599
385 19.670238 30.495247 30.495247 0.826349
386 20.088600 28.658175 28.658175 0.835207
387 19.868584 31.086962 31.086962 0.838345
388 19.701395 32.066730 32.066730 0.835094
389 19.621452 28.678385 28.678385 0.808998
390 19.700653 24.997456 24.997456 0.803856
391 19.574583 29.561003 29.561003 0.821678
392 19.870275 28.705059 28.705059 0.811017
393 19.639271 28.516575 28.516575 0.819947
394 19.638912 30.155981 30.155981 0.833960
395 19.592838 29.500837 29.500837 0.841343
396 19.401573 28.972284 28.972284 0.829232
397 19.476780 25.117693 25.117693 0.800549
398 19.415213 27.676556 27.676556 0.817315
399 19.353481 29.587238 29.587238 0.835576
400 19.329996 30.589844 30.589844 0.814549
401 19.388426 28.798521 28.798521 0.796778
402 19.391117 27.288851 27.288851 0.820750
403 19.465256 31.393837 31.393837 0.836368
404 19.367193 32.006042 32.006042 0.836113
405 19.529978 28.048588 28.048588 0.818935
406 19.145683 29.942076 29.942076 0.831204
407 19.258900 29.778469 29.778469 0.806376
408 19.101170 28.129093 28.129093 0.901656
409 19.261408 32.255196 32.255196 0.827982
410 19.313683 28.843971 28.843971 0.821732
411 19.165995 28.464394 28.464394 0.829256
412 19.210093 35.132942 35.132942 0.832075
413 19.264273 30.158085 30.158085 0.823937
414 19.138166 30.775288 30.775288 0.842165
415 18.951040 30.801260 30.801260 0.837962
416 19.047180 33.842201 33.842201 0.877818
417 18.937719 31.111700 31.111700 0.839707
418 18.912231 26.090418 26.090418 0.819153
419 18.981651 25.895103 25.895103 0.810066
420 19.009682 33.740959 33.740959 0.825031
421 19.068413 26.207594 26.207594 0.800084
422 19.138927 27.260050 27.260050 0.826087
423 19.042442 38.317123 38.317123 0.855723
424 18.792263 30.573797 30.573797 0.821726
425 18.854481 30.468676 30.468676 0.832594
426 18.829100 28.783182 28.783182 0.828581
427 18.863146 28.578630 28.578630 0.801191
428 18.836004 27.001785 27.001785 0.794121
429 18.850929 32.737068 32.737068 0.842025
430 18.769400 35.942917 35.942917 0.829415
431 18.702246 26.559973 26.559973 0.800133
432 18.844038 28.286423 28.286423 0.820270
433 18.560753 28.310850 28.310850 0.821260
434 18.794085 31.361650 31.361650 0.812626
435 18.681713 31.365549 31.365549 0.848335
436 18.713829 33.909851 33.909851 0.825864
437 18.699177 27.860062 27.860062 0.840470
438 18.622734 28.208563 28.208563 0.805614
439 18.489710 26.838825 26.838825 0.810255
440 18.533819 30.129074 30.129074 0.844464
441 18.456362 29.615694 29.615694 0.817194
442 18.592129 30.206934 30.206934 0.817924
443 18.747429 29.960958 29.960958 0.821688
444 18.497389 32.355579 32.355579 0.825524
445 18.496330 33.963696 33.963696 0.839722
446 18.526970 27.462112 27.462112 0.807560
447 18.463797 30.403719 30.403719 0.832315
448 18.292990 29.474281 29.474281 0.806896
449 18.556654 27.655251 27.655251 0.793418
450 18.317999 31.544628 31.544628 0.822906
451 18.420084 29.626957 29.626957 0.811272
452 18.295961 27.606194 27.606194 0.810014
453 18.390396 28.956913 28.956913 0.814843
454 18.216801 28.018705 28.018705 0.801500
455 18.274914 30.064688 30.064688 0.815990
456 18.305866 31.102724 31.102724 0.828104
457 18.174770 32.018856 32.018856 0.821333
458 18.388660 29.696213 29.696213 0.824469
459 18.236008 27.600956 27.600956 0.797204
460 18.253284 26.765680 26.765680 0.793056
461 18.275747 30.504982 30.504982 0.797631
462 18.105152 29.080460 29.080460 0.807240
463 18.224957 29.199797 29.199797 0.810738
464 17.977352 30.023697 30.023697 0.806384
465 18.100113 28.936775 28.936775 0.805018
466 18.182266 35.798580 35.798580 0.841455
467 18.075848 29.712828 29.712828 0.808752
468 18.069944 30.718275 30.718275 0.841107
469 18.028542 31.286587 31.286587 0.818650
470 18.039639 27.987673 27.987673 0.790864
471 18.049980 26.684250 26.684250 0.809462
472 18.041315 29.130526 29.130526 0.798955
473 18.041491 34.469986 34.469986 0.832688
474 17.873390 25.967432 25.967432 0.804805
475 18.024952 26.198193 26.198193 0.795796
476 18.130011 26.925938 26.925938 0.808613
477 17.924801 29.644575 29.644575 0.809224
478 18.003313 30.549294 30.549294 0.812313
479 18.007442 34.245888 34.245888 0.839452
480 17.958710 27.796597 27.796597 0.816399
481 17.787209 32.129025 32.129025 0.826926
482 17.836800 29.350918 29.350918 0.806177
483 17.838070 27.451387 27.451387 0.811297
484 17.926260 28.854450 28.854450 0.794694
485 17.941809 29.902784 29.902784 0.810505
486 17.807796 31.584600 31.584600 0.828748
487 18.123257 25.971964 25.971964 0.802340
488 17.923618 30.926918 30.926918 0.807232
489 18.017157 34.957954 34.957954 0.828751
490 17.863190 35.238113 35.238113 0.826141
491 17.845936 28.953009 28.953009 0.810183
492 17.760710 30.270910 30.270910 0.810206
493 17.636686 29.655172 29.655172 0.792839
494 17.646048 27.377762 27.377762 0.795583
495 17.718960 27.934299 27.934299 0.801404
496 17.782598 27.003584 27.003584 0.789616
497 17.744152 29.458599 29.458599 0.805626
498 17.659798 27.889635 27.889635 0.803372
499 17.586069 27.337358 27.337358 0.804841
500 17.465101 27.234043 27.234043 0.796919
501 17.617983 32.175045 32.175045 0.828921
502 17.517565 28.652172 28.652172 0.801894
503 17.607765 28.732475 28.732475 0.797819
504 17.634298 29.661631 29.661631 0.811863
505 17.451544 28.194799 28.194799 0.812533
506 17.535856 27.850016 27.850016 0.802811
507 17.653004 28.458843 28.458843 0.815564
508 17.485518 27.116488 27.116488 0.789522
509 17.657616 31.270426 31.270426 0.816570
510 17.550051 28.537813 28.537813 0.794867
511 17.456673 26.983768 26.983768 0.808765
512 17.461521 32.131729 32.131729 0.811588
513 17.400482 32.266327 32.266327 0.810670
514 17.409887 27.117775 27.117775 0.791690
515 17.444668 28.755241 28.755241 0.802858
516 17.439150 27.350828 27.350828 0.791582
517 17.425848 27.947987 27.947987 0.792039
518 17.374071 29.243031 29.243031 0.813625
519 17.407967 32.309826 32.309826 0.801910
520 17.316252 29.545788 29.545788 0.817326
521 17.313868 28.906019 28.906019 0.810315
522 17.410135 30.232071 30.232071 0.817198
523 17.382107 30.760994 30.760994 0.812093
524 17.381172 27.351587 27.351587 0.790295
525 17.568604 29.831865 29.831865 0.783787
526 17.461163 34.612133 34.612133 0.832215
527 17.314554 26.954687 26.954687 0.793182
528 17.483107 36.418304 36.418304 0.840229
529 17.310074 27.303513 27.303513 0.804458
530 17.327452 32.026627 32.026627 0.813015
531 17.321669 25.205114 25.205114 0.796766
532 17.192396 26.289732 26.289732 0.804069
533 17.361225 30.578878 30.578878 0.805598
534 17.260796 27.188999 27.188999 0.793238
535 17.303936 30.503344 30.503344 0.798804
536 17.307570 28.446278 28.446278 0.808474
537 17.246763 28.973631 28.973631 0.810209
538 17.247349 28.571320 28.571320 0.807079
539 17.199726 29.678768 29.678768 0.805977
540 17.208178 27.343475 27.343475 0.797696
541 17.080887 26.155279 26.155279 0.792784
542 17.105354 28.845409 28.845409 0.812772
543 17.183191 26.199879 26.199879 0.785627
544 17.099640 28.239481 28.239481 0.806001
545 17.122894 27.121138 27.121138 0.783365
546 17.191368 28.866413 28.866413 0.799720
547 17.174528 28.363388 28.363388 0.793286
548 17.141420 26.097416 26.097416 0.797240
549 17.219872 32.718220 32.718220 0.814210
550 17.072456 28.275219 28.275219 0.803153
551 17.203091 26.738821 26.738821 0.793523
552 17.089849 30.534100 30.534100 0.799119
553 17.178310 30.255531 30.255531 0.804183
554 17.007668 25.847328 25.847328 0.792709
555 17.178520 32.843349 32.843349 0.820920
556 16.983553 27.881701 27.881701 0.799397
557 17.079176 29.980043 29.980043 0.815065
558 17.071980 33.022240 33.022240 0.807456
559 17.049711 26.611809 26.611809 0.788952
560 17.131897 28.906601 28.906601 0.797059
561 17.051199 35.488522 35.488522 0.820983
562 17.138382 32.740532 32.740532 0.809891
563 17.273539 35.268623 35.268623 0.817666
564 17.163408 29.226507 29.226507 0.799978
565 17.175001 34.330811 34.330811 0.820431
566 17.057478 29.133484 29.133484 0.797177
567 16.912235 25.953484 25.953484 0.785920
568 16.994179 29.590284 29.590284 0.791411
569 17.040358 30.785744 30.785744 0.809470
570 17.007362 31.357382 31.357382 0.798695
571 17.104525 26.843304 26.843304 0.799187
572 16.939304 27.436810 27.436810 0.796603
573 17.098824 34.378216 34.378216 0.817264
574 16.954121 31.910265 31.910265 0.802028
575 16.983212 28.555937 28.555937 0.800918
576 17.155170 30.026846 30.026846 0.811817
577 17.242094 36.562836 36.562836 0.814757
578 17.060047 28.909149 28.909149 0.802118
579 16.938492 27.774906 27.774906 0.789041
580 16.952663 24.950550 24.950550 0.784761
581 16.997622 29.644119 29.644119 0.792056
582 17.040215 29.872421 29.872421 0.802756
583 16.964943 26.599747 26.599747 0.785250
584 17.105183 35.066528 35.066528 0.816200
585 17.029114 29.482920 29.482920 0.794696
586 16.967745 29.923212 29.923212 0.804836
587 17.068871 36.625725 36.625725 0.813744
588 17.066595 29.443745 29.443745 0.804752
589 17.046106 31.077038 31.077038 0.796184
590 17.081030 36.422691 36.422691 0.825810
591 16.980398 31.546888 31.546888 0.809645
592 16.957745 27.725594 27.725594 0.801067
593 16.934332 28.236074 28.236074 0.790647
594 17.016115 28.282209 28.282209 0.794521
595 17.015324 27.978348 27.978348 0.796241
596 17.026548 28.905725 28.905725 0.794053
597 17.034653 28.156298 28.156298 0.798595
598 17.076344 30.742912 30.742912 0.808406
599 16.990126 28.792730 28.792730 0.801815
600 16.951323 26.981087 26.981087 0.788165


In [129]:
# Training vs. validation loss curves for the run above.
l.recorder.plot_losses();



In [6]:
# Persist the trained weights under a name encoding the run configuration.
l.save(
    f"r_speedup_{optimizer}_batch_norm_{batch_norm}_{loss_func}"
    f"_nlayers_{len(layers_sizes)}_log_{log}"
)

In [7]:
# List saved checkpoints. The r_speedup_*.pth entry in the output matches the name
# passed to l.save above, so this is presumably the learner's model directory.
!ls models


old_models
old_repr
r_speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False2.pth
speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MSE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MSE_nlayers_5_log_True.pth
tmp.pth

In [111]:
# Build per-sample result frames for the validation and training splits using the
# trained model. The cells below read columns prediction, target, abs_diff, APE,
# interchange, unroll and tile from them.
val_df = get_results_df(val_dl, l.model)
train_df = get_results_df(train_dl, l.model)

In [116]:
# Point the generic `df` alias (read by the summary/plot cells below) at the training frame.
# NOTE(review): this rebinding is hidden state — those cells describe whichever
# alias cell ran last, not the one nearest in notebook order.
df = train_df

In [117]:
# Summary statistics of predictions and errors over the whole frame.
df[['prediction', 'target', 'abs_diff', 'APE']].describe()


Out[117]:
prediction target abs_diff APE
count 185651.000000 185651.000000 1.856510e+05 185651.000000
mean 1.159134 1.184810 1.515271e-01 16.350126
std 1.564008 1.613917 4.039645e-01 49.778847
min 0.010645 0.008491 2.384186e-07 0.000060
25% 0.216299 0.211613 1.029227e-02 3.166194
50% 0.569257 0.565550 3.776443e-02 7.030761
75% 1.389433 1.394052 1.252649e-01 13.793568
max 11.054943 16.089287 1.534635e+01 3100.805420

In [118]:
# Point the generic `df` alias (read by the summary/plot cells below) at the validation frame.
# NOTE(review): this rebinding is hidden state — those cells describe whichever
# alias cell ran last, not the one nearest in notebook order.
df = val_df

In [119]:
# Summary statistics of predictions and errors over the whole frame.
df[['prediction', 'target', 'abs_diff', 'APE']].describe()


Out[119]:
prediction target abs_diff APE
count 10000.000000 10000.000000 10000.000000 10000.000000
mean 1.717750 1.601825 0.545405 50.500225
std 1.808425 1.653595 0.805434 75.607666
min 0.012248 0.010685 0.000005 0.001589
25% 0.372329 0.301894 0.054618 8.127786
50% 1.037479 0.970127 0.214614 23.571514
75% 2.149681 2.176448 0.696176 61.788719
max 8.437875 7.522550 5.974429 654.431885

In [120]:
# Rows with no flag set (interchange == 0, unroll == 0, tile == 0).
df.loc[
    (df.interchange == 0) & (df.unroll == 0) & (df.tile == 0),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[120]:
prediction target abs_diff APE
count 25.000000 25.0 25.000000 25.000000
mean 2.385324 1.0 1.438993 143.899323
std 1.898048 0.0 1.855988 185.598831
min 0.731159 1.0 0.021526 2.152550
25% 0.973027 1.0 0.069987 6.998700
50% 1.585838 1.0 0.585838 58.583771
75% 3.934247 1.0 2.934247 293.424683
max 6.791101 1.0 5.791101 579.110046

In [48]:
# Rows with only tile set (interchange == 0, unroll == 0, tile == 1).
# NOTE(review): the saved output (count 10908) exceeds val_df's 10000 rows and
# predates `df = val_df` (In [118]); it is stale — re-run before comparing.
df.loc[
    (df.interchange == 0) & (df.unroll == 0) & (df.tile == 1),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[48]:
prediction target abs_diff APE
count 10908.000000 10908.000000 10908.000000 10908.000000
mean 1.300379 1.544865 0.348264 17.498978
std 1.134063 1.561985 0.874742 19.460262
min 0.040741 0.080195 0.000011 0.000417
25% 0.521132 0.561271 0.037481 4.959255
50% 0.952309 1.011548 0.103626 11.225842
75% 1.701679 1.967590 0.268848 22.333059
max 7.149683 16.089287 15.834622 476.782654

In [49]:
# Rows with only unroll set (interchange == 0, unroll == 1, tile == 0).
# NOTE(review): the saved output predates `df = val_df` (In [118]) — re-run before comparing.
df.loc[
    (df.interchange == 0) & (df.unroll == 1) & (df.tile == 0),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[49]:
prediction target abs_diff APE
count 5388.000000 5388.000000 5388.000000 5388.000000
mean 4.646883 5.080168 0.709873 17.410728
std 2.274305 2.592474 0.847952 34.551174
min 0.041153 0.058317 0.000006 0.000110
25% 2.416280 2.779300 0.175089 4.599086
50% 5.243072 5.502153 0.416792 10.081613
75% 6.300526 6.526127 0.916621 17.963830
max 8.230931 13.560771 6.250570 510.743958

In [50]:
# Rows with only interchange set (interchange == 1, unroll == 0, tile == 0).
# NOTE(review): the saved output predates `df = val_df` (In [118]) — re-run before comparing.
df.loc[
    (df.interchange == 1) & (df.unroll == 0) & (df.tile == 0),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[50]:
prediction target abs_diff APE
count 1758.000000 1758.000000 1758.000000 1758.000000
mean 1.041926 1.452486 0.451720 23.524038
std 1.040951 1.578704 0.813260 25.382376
min 0.041223 0.018637 0.000014 0.005628
25% 0.306825 0.347235 0.030223 6.690400
50% 0.681719 0.822524 0.115805 15.353576
75% 1.462868 2.103264 0.442162 31.986423
max 6.201118 9.472253 5.777220 405.513367

In [51]:
# Rows with unroll and tile set (interchange == 0, unroll == 1, tile == 1).
# NOTE(review): the saved output (count 21395) exceeds val_df's 10000 rows and
# predates `df = val_df` (In [118]); it is stale — re-run before comparing.
df.loc[
    (df.interchange == 0) & (df.unroll == 1) & (df.tile == 1),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[51]:
prediction target abs_diff APE
count 21395.000000 21395.000000 21395.000000 21395.000000
mean 1.103823 1.345353 0.325749 19.157253
std 1.122763 1.548089 0.875139 22.390285
min 0.037974 0.029894 0.000007 0.000545
25% 0.387534 0.399394 0.031919 5.276947
50% 0.853066 0.889977 0.086671 11.755613
75% 1.368872 1.568050 0.212368 23.404687
max 8.301966 13.433655 12.628229 364.638855

In [52]:
# Rows with interchange and unroll set (interchange == 1, unroll == 1, tile == 0).
# NOTE(review): the saved output (count 26827) exceeds val_df's 10000 rows and
# predates `df = val_df` (In [118]); it is stale — re-run before comparing.
df.loc[
    (df.interchange == 1) & (df.unroll == 1) & (df.tile == 0),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[52]:
prediction target abs_diff APE
count 26827.000000 26827.000000 26827.000000 26827.000000
mean 2.278949 2.478805 0.312914 13.936263
std 1.690642 1.895653 0.552439 35.758522
min 0.041220 0.018075 0.000003 0.000124
25% 0.873289 0.883696 0.049614 3.965816
50% 1.820674 1.956570 0.140420 8.129169
75% 3.415022 3.795769 0.346526 14.133539
max 7.781519 13.558847 6.984952 659.840881

In [53]:
# Rows with interchange and tile set (interchange == 1, unroll == 0, tile == 1).
# NOTE(review): the saved output (count 38476) exceeds val_df's 10000 rows and
# predates `df = val_df` (In [118]); it is stale — re-run before comparing.
df.loc[
    (df.interchange == 1) & (df.unroll == 0) & (df.tile == 1),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[53]:
prediction target abs_diff APE
count 38476.000000 38476.000000 38476.000000 38476.000000
mean 0.599790 0.718080 0.161907 16.824739
std 0.726942 0.959410 0.441344 19.137383
min 0.011075 0.008774 0.000001 0.002214
25% 0.159066 0.163858 0.010282 4.783708
50% 0.350566 0.374172 0.037043 10.798526
75% 0.797795 0.897981 0.118515 21.021241
max 6.714707 9.643795 7.808821 224.782593

In [54]:
# Rows with all three flags set (interchange == 1, unroll == 1, tile == 1).
# NOTE(review): the saved output (count 80352) exceeds val_df's 10000 rows and
# predates `df = val_df` (In [118]); it is stale — re-run before comparing.
df.loc[
    (df.interchange == 1) & (df.unroll == 1) & (df.tile == 1),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[54]:
prediction target abs_diff APE
count 80352.000000 80352.000000 8.035200e+04 80352.000000
mean 0.502504 0.618813 1.499550e-01 16.857380
std 0.697265 0.935955 4.647943e-01 20.027716
min 0.011103 0.008491 2.942979e-07 0.000142
25% 0.118478 0.126692 8.149397e-03 4.663052
50% 0.289729 0.305125 2.876675e-02 10.372399
75% 0.639897 0.745241 9.044305e-02 20.579370
max 7.473004 12.004527 8.477986e+00 453.119904

In [56]:
# Rows where the flag sum is non-zero (at least one of interchange/tile/unroll set).
# NOTE(review): the saved output (count 185104) cannot come from val_df (10000 rows)
# and predates `df = val_df` (In [118]); it is stale — re-run before comparing.
df.loc[
    (df.interchange + df.tile + df.unroll).ne(0),
    ['prediction', 'target', 'abs_diff', 'APE'],
].describe()


Out[56]:
prediction target abs_diff APE
count 185104.000000 185104.000000 1.851040e+05 185104.000000
mean 1.022474 1.185355 2.272181e-01 16.810253
std 1.374168 1.616282 5.942063e-01 23.641901
min 0.011075 0.008491 2.942979e-07 0.000110
25% 0.197577 0.210773 1.411321e-02 4.647739
50% 0.502107 0.562228 5.195722e-02 10.257700
75% 1.149322 1.398968 1.731274e-01 19.895230
max 8.301966 16.089287 1.583462e+01 659.840881

In [76]:
# Joint plots of the current `df`:
#   1. for each on/off combination of the (interchange, unroll, tile) flags,
#   2. for rows with at least one flag set,
#   3. for the whole frame.
# Each combination is plotted once, and titles name the subset so figures can be
# told apart. Empty subsets are skipped rather than handed to joint_plot.
# NOTE(review): titles say "Validation dataset", but the data is whatever `df`
# currently references (see the `df = train_df` / `df = val_df` cells) — confirm.
flag_combinations = [
    (0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0),
    (0, 1, 1), (1, 1, 0), (1, 0, 1), (1, 1, 1),
]
for interchange_on, unroll_on, tile_on in flag_combinations:
    df1 = df[(df.interchange == interchange_on) & (df.unroll == unroll_on) & (df.tile == tile_on)]
    if df1.empty:
        # No rows for this combination in the current frame; nothing to plot.
        continue
    joint_plot(df1, f"Validation dataset, {loss_func} loss "
                    f"(interchange={interchange_on}, unroll={unroll_on}, tile={tile_on})")

df1 = df[(df.interchange + df.tile + df.unroll != 0)]
joint_plot(df1, f"Validation dataset, {loss_func} loss (at least one flag set)")

df2 = df
joint_plot(df2, f"Validation dataset, {loss_func} loss (all rows)")



In [ ]: