In [37]:
# Reload the autoreload extension and have IPython re-import edited modules
# automatically, so changes to the project's .py files are picked up without
# restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import matplotlib.pyplot as plt
# Default figure size as [width, height]; matplotlib interprets these in inches.
plt.rcParams['figure.figsize'] = [40, 30]

In [1]:
import fastai as fai
from fastai.basic_data import DataLoader
from data_loader import *
import numpy as np
from torch.utils.data import SubsetRandomSampler
from model import *
from model_bn import *
from torch import optim
import dill
import papermill as pm
from main import *
import seaborn as sns
import pandas as pd

In [2]:
# Run configuration. Everything a reader might want to tune lives in this cell.

# Optimiser selection: 'SGD' switches the learner to torch.optim.SGD further
# down; any other value keeps the learner's default optimiser.
optimizer = "Adam"

# Data-loading settings handed to `train_dev_split`.
dataset = "data/speedup_dataset.h5"
batch_size = 2048
num_workers = 8
maxsize = 100_000

# Epoch budget for training.
n_epochs = 500

# When True the batch-normalised network (`Model_BN`) is built instead of `Model`.
batch_norm = True

In [3]:
# Build the training and validation loaders from the dataset file, using the
# batch size, worker count and `maxsize` set in the parameters cell.
# NOTE(review): `train_dev_split` arrives through a star import; the split
# ratio and the exact meaning of `maxsize` are not visible here — confirm.
train_dl, val_dl = train_dev_split(dataset, batch_size, num_workers, maxsize)

# Bundle both loaders into a fastai DataBunch, which is what the Learner consumes.
db = fai.basic_data.DataBunch(train_dl, val_dl)

In [4]:
def criterion(inputs, targets, eps=1e-5):
    """Mean absolute percentage error between predictions and targets.

    Args:
        inputs: Tensor of model predictions.
        targets: Tensor of ground-truth values, broadcastable against ``inputs``.
        eps: Small constant added to the denominator so that a zero target
            does not cause a division by zero.

    Returns:
        Scalar tensor holding ``mean(|targets - inputs| / (|targets| + eps) * 100)``.
    """
    # Divide by the magnitude of the target. With a signed denominator a
    # negative target yields a negative term (so larger errors would *lower*
    # the loss) and a target of exactly -eps divides by zero. For positive
    # targets the result is unchanged.
    relative_error = torch.abs(targets - inputs) / (torch.abs(targets) + eps)
    return torch.mean(relative_error * 100)

In [5]:
# Network dimensions come from the second axis of the training dataset's
# X (features) and Y (targets) attributes.
input_size = train_dl.dataset.X.shape[1]
output_size = train_dl.dataset.Y.shape[1]

# Widths of the successive hidden layers used by the batch-norm network.
layers_sizes = [300, 200, 120, 80, 30]

# Select the architecture: the batch-normalised variant takes the hidden-layer
# widths, the plain variant is built with its own defaults.
model = (
    Model_BN(input_size, output_size, hidden_sizes=layers_sizes)
    if batch_norm
    else Model(input_size, output_size)
)

# The learner optimises the percentage-error `criterion` defined earlier
# (nn.MSELoss() was the alternative loss tried here).
l = fai.Learner(db, model, loss_func=criterion)

# Only SGD needs an explicit override; otherwise the learner's default
# optimiser is kept.
if optimizer == 'SGD':
    l.opt_func = optim.SGD

In [44]:
# Show the layer structure of the network wrapped by the learner.
print(l.model)


Model_BN(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=381, out_features=300, bias=False)
    (1): Linear(in_features=300, out_features=200, bias=False)
    (2): Linear(in_features=200, out_features=120, bias=False)
    (3): Linear(in_features=120, out_features=80, bias=False)
    (4): Linear(in_features=80, out_features=30, bias=False)
  )
  (batch_norm_layers): ModuleList(
    (0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (predict): Linear(in_features=30, out_features=1, bias=True)
)

In [116]:
# Run fastai's learning-rate finder; as the message it prints says, the result
# is inspected with `l.recorder.plot()`.
l.lr_find()


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

In [117]:
# Plot what the recorder captured during the learning-rate finder run.
l.recorder.plot()



In [136]:
# Learning rate handed to `fit_one_cycle` below.
lr = 1e-5

In [137]:
# Train with fastai's one-cycle schedule for 1200 epochs, passing `lr` as the
# learning-rate argument.
# NOTE(review): the epoch count is hard-coded; the `n_epochs` parameter (500)
# defined in the parameters cell is not used here — confirm which is intended.
l.fit_one_cycle(1200, lr)


Total time: 2:32:36

epoch train_loss valid_loss
1 43.308823 42.511745
2 43.440548 42.512672
3 43.522625 42.832088
4 43.359398 42.508614
5 43.424362 42.823555
6 43.302296 42.570656
7 43.337223 42.466236
8 43.198784 42.481415
9 43.276703 42.472767
10 43.393902 42.904263
11 43.391769 42.667290
12 43.403667 42.497829
13 43.380348 42.411304
14 43.314499 42.548527
15 43.428509 42.706577
16 43.367168 42.650780
17 43.343029 42.677246
18 43.319000 42.489429
19 43.391106 42.656219
20 43.388210 42.591579
21 43.298534 42.557522
22 43.390308 42.564564
23 43.285198 42.421730
24 43.337765 42.498516
25 43.398766 42.607731
26 43.432446 42.581360
27 43.537910 42.634064
28 43.363823 42.551136
29 43.272110 42.537769
30 43.310291 42.701267
31 43.352707 42.672440
32 43.470226 42.658062
33 43.471638 43.215084
34 43.381989 42.497452
35 43.275143 42.465843
36 43.454556 42.537891
37 43.342133 42.553917
38 43.491207 42.599930
39 43.492729 42.788162
40 43.393791 42.519951
41 43.326050 42.508144
42 43.264702 42.469883
43 43.262543 42.562611
44 43.189724 42.455605
45 43.508636 43.351971
46 43.469368 42.524853
47 43.415024 42.548290
48 43.464169 42.469677
49 43.532379 42.680756
50 43.446045 42.569805
51 43.244213 42.418137
52 43.348724 42.588364
53 43.323513 42.531250
54 43.429993 42.579483
55 43.407803 42.584702
56 43.445045 42.670910
57 43.460434 42.590237
58 43.447296 42.722595
59 43.325001 42.600826
60 43.355171 42.485855
61 43.348545 42.517017
62 43.279091 42.603905
63 43.369083 42.579792
64 43.432575 42.719604
65 43.240456 42.461514
66 43.270180 42.428192
67 43.428951 42.719746
68 43.407288 42.508919
69 43.487976 42.704956
70 43.427967 42.735149
71 43.384842 42.591808
72 43.280643 42.450020
73 43.363865 42.864948
74 43.431782 42.713253
75 43.452579 42.648865
76 43.370678 42.467739
77 43.367920 42.591820
78 43.376579 42.467312
79 43.279800 42.487236
80 43.280773 42.534298
81 43.337933 42.719807
82 43.303379 42.480621
83 43.540298 42.723038
84 43.448555 42.610886
85 43.321148 42.382626
86 43.288475 42.521107
87 43.310978 42.427227
88 43.354904 42.763062
89 43.357090 42.701160
90 43.235104 42.497105
91 43.286869 42.386902
92 43.400986 42.518703
93 43.344879 42.422607
94 43.313339 42.456898
95 43.344193 42.705093
96 43.289135 42.512676
97 43.295864 42.409733
98 43.399303 42.475636
99 43.362911 42.493778
100 43.316734 42.766773
101 43.316734 42.588840
102 43.397621 42.586967
103 43.487553 42.623798
104 43.426090 42.680416
105 43.335175 42.450436
106 43.267071 42.332695
107 43.291157 42.436569
108 43.248577 42.464687
109 43.265816 42.339283
110 43.259209 42.500687
111 43.165733 42.683731
112 43.152603 42.397896
113 43.320751 42.639519
114 43.263523 42.415813
115 43.270466 42.813183
116 43.334484 42.662815
117 43.389759 42.488380
118 43.325508 42.603394
119 43.361469 42.493492
120 43.335346 42.578968
121 43.294811 42.548058
122 43.222107 42.366676
123 43.160793 42.383427
124 43.141220 42.593845
125 43.309528 42.506844
126 43.144165 42.516418
127 43.284153 42.507442
128 43.323742 42.441414
129 43.351765 42.560127
130 43.193291 42.463081
131 43.429253 42.546333
132 43.271992 42.470543
133 43.363823 42.749699
134 43.237473 42.513920
135 43.264339 42.445187
136 43.331261 42.561455
137 43.312698 42.553104
138 43.244049 42.574196
139 43.318817 42.570320
140 43.194126 42.338051
141 43.211468 42.325531
142 43.285694 42.543560
143 43.373501 42.735111
144 43.335075 42.575508
145 43.193237 42.311756
146 43.247868 42.399548
147 43.201267 42.263844
148 43.297466 42.382526
149 43.224606 42.407101
150 43.180817 42.357018
151 43.379089 42.760448
152 43.277565 42.502728
153 43.373203 42.504719
154 43.370510 42.435719
155 43.421436 42.638718
156 43.295113 42.404251
157 43.295193 42.432026
158 43.393440 42.474426
159 43.315281 42.686474
160 43.289707 42.407482
161 43.250385 42.288689
162 43.222820 42.413395
163 43.211552 42.366455
164 43.263912 42.659931
165 43.315254 42.561066
166 43.311863 42.646599
167 43.198475 42.334595
168 43.190556 42.534267
169 43.329521 42.330570
170 43.321499 42.354347
171 43.227497 42.327354
172 43.335384 42.516888
173 43.269951 42.355476
174 43.178883 42.432110
175 43.089993 42.195763
176 43.221603 42.362789
177 43.184181 42.476276
178 43.176987 42.400883
179 43.170002 42.517239
180 43.141499 42.437321
181 43.233356 42.316410
182 43.172886 42.436558
183 43.143227 42.500553
184 43.285915 42.745098
185 43.280193 42.496872
186 43.258728 42.520077
187 43.211422 42.226952
188 43.184296 42.244671
189 43.349567 42.674175
190 43.151520 42.397049
191 43.199123 42.401440
192 43.136684 42.312008
193 43.166206 42.420994
194 43.263863 42.250496
195 43.133865 42.286366
196 43.206516 42.311127
197 43.265869 42.431080
198 43.268402 42.437626
199 43.119148 42.320129
200 43.143574 42.650581
201 43.257111 42.446098
202 43.112968 42.259274
203 43.098816 42.234215
204 43.317780 43.193264
205 43.265362 42.595005
206 43.120068 42.211823
207 43.050522 42.187096
208 43.137661 42.296898
209 43.038986 42.228149
210 43.204395 42.334774
211 43.156063 42.152214
212 43.167034 42.443050
213 43.014675 42.193607
214 43.183807 42.403728
215 43.211624 42.190807
216 43.130669 42.245991
217 43.168030 42.448418
218 43.108006 42.431931
219 43.096741 42.375481
220 43.227245 42.572723
221 43.045704 42.201012
222 43.024002 42.356026
223 43.080677 42.380455
224 42.980579 42.152283
225 43.053883 42.308189
226 42.940228 42.161774
227 43.005127 42.188911
228 43.033890 42.173935
229 42.992813 42.585255
230 42.999443 42.258640
231 42.961220 42.297909
232 42.917561 42.176254
233 43.008018 42.334019
234 43.051231 42.351738
235 43.018929 42.188313
236 42.975731 42.186150
237 42.963436 42.153095
238 43.129246 42.566322
239 43.091927 42.312054
240 43.121372 42.086952
241 43.029343 42.135124
242 43.044067 42.219898
243 43.012306 42.122932
244 43.014561 42.322819
245 43.053043 42.117138
246 43.017601 42.223198
247 43.045921 42.410545
248 43.019180 42.228680
249 43.002071 42.257088
250 43.103962 42.381157
251 42.967606 42.110027
252 42.945602 42.112461
253 42.846420 41.988468
254 42.916729 42.042744
255 43.028671 42.331482
256 42.947941 42.170731
257 42.880531 42.169357
258 42.901554 42.307484
259 43.014194 42.091679
260 42.970135 42.197788
261 43.011520 42.168026
262 43.005985 42.154202
263 43.004566 42.135094
264 42.884888 41.997787
265 42.831566 42.002308
266 42.959160 42.288376
267 42.924625 42.021477
268 43.017998 42.208313
269 43.120888 42.410500
270 43.162666 42.276325
271 43.073608 41.938839
272 42.901661 41.855385
273 42.842232 42.272804
274 42.882057 41.952637
275 42.916054 42.177704
276 42.792236 42.238007
277 42.890148 42.084408
278 42.828568 42.205986
279 42.811569 42.049580
280 43.243748 42.488880
281 43.065220 42.080807
282 42.888756 42.077328
283 42.845760 41.826656
284 42.935131 42.148705
285 42.760319 42.120308
286 42.879009 42.045528
287 42.862892 41.905067
288 42.982159 42.151749
289 43.023952 42.133980
290 43.071560 42.426815
291 42.938454 42.044838
292 42.861984 42.103294
293 43.003902 42.481201
294 42.831276 42.051811
295 42.980816 42.079124
296 42.809906 41.972889
297 42.792458 42.026894
298 42.810532 42.067600
299 42.822411 42.184246
300 42.779461 41.771545
301 42.880264 42.175171
302 42.844345 41.856579
303 42.655869 41.949787
304 42.701820 41.751923
305 42.678684 41.779453
306 42.744370 41.835495
307 42.813496 42.149643
308 42.879421 42.384701
309 43.057240 42.097874
310 42.972980 41.883121
311 42.877312 42.190792
312 43.047871 42.169724
313 42.931076 41.951782
314 42.758377 41.773201
315 42.701591 41.921288
316 42.953884 42.152615
317 42.756279 41.882862
318 42.677769 41.713383
319 42.578537 41.687145
320 42.628929 41.960384
321 42.529114 41.883400
322 42.572681 41.885738
323 42.579865 41.658867
324 42.808060 42.204353
325 42.771034 42.056339
326 42.570560 41.774673
327 42.529305 41.784626
328 42.615219 41.629383
329 42.660172 42.022442
330 42.852463 42.097675
331 42.759567 42.097233
332 42.576832 41.618286
333 42.746090 42.166943
334 42.922039 42.154999
335 43.016827 42.046646
336 42.746376 41.685814
337 42.636288 41.800156
338 42.586994 42.034973
339 42.475388 41.628414
340 42.426167 41.605614
341 42.472683 41.705425
342 42.565437 42.146664
343 42.488110 41.637955
344 42.781971 42.115387
345 42.607430 41.648045
346 42.512070 41.692696
347 42.520882 41.668137
348 42.613682 41.942402
349 42.654659 42.079990
350 42.608181 41.661129
351 42.572063 41.775986
352 42.537857 41.571064
353 42.544552 41.509159
354 42.469330 41.781708
355 42.551334 41.602943
356 42.440598 41.464443
357 42.496616 41.676586
358 42.609161 41.565777
359 42.489223 41.349514
360 42.752827 41.785976
361 42.677479 41.707218
362 42.543865 41.716900
363 42.322128 41.387867
364 42.452339 41.645294
365 42.441360 41.609970
366 42.490810 41.658108
367 42.725384 42.023575
368 42.652695 41.678944
369 42.660336 41.903721
370 42.630302 41.843071
371 42.478783 41.507824
372 42.419155 41.640976
373 42.410950 41.892738
374 42.327682 41.737839
375 42.599144 41.745640
376 42.454029 41.552967
377 42.417725 41.545242
378 42.400539 41.380894
379 42.446865 41.659908
380 42.384380 41.816437
381 42.403580 41.594154
382 42.443386 41.938171
383 42.359291 41.371670
384 42.537037 41.883488
385 42.516636 41.500137
386 42.471733 41.739201
387 42.418522 41.567677
388 42.439098 41.529274
389 42.334743 41.361462
390 42.333385 41.571407
391 42.385750 41.863811
392 42.593742 41.712864
393 42.454536 41.454147
394 42.563129 41.970089
395 42.420856 41.415218
396 42.879295 42.796795
397 42.606812 41.544258
398 42.390259 41.528698
399 42.222370 41.487305
400 42.263580 41.282112
401 42.267460 41.468239
402 42.326324 41.465820
403 42.563320 42.024467
404 42.418571 41.707989
405 42.355164 41.242661
406 42.408871 42.003685
407 42.207825 41.221474
408 42.235348 41.318092
409 42.317959 41.438202
410 42.319637 41.463772
411 42.231670 41.405922
412 42.289207 41.522751
413 42.288872 41.409077
414 42.088970 41.223618
415 42.208500 41.662086
416 42.193970 41.502048
417 42.024624 41.080818
418 42.127243 41.360786
419 42.162781 41.379070
420 42.239967 41.522907
421 42.289818 41.674026
422 42.249401 41.432499
423 42.369846 41.651722
424 42.364357 41.472591
425 42.301571 41.310585
426 42.153515 41.232483
427 42.105740 41.359154
428 42.209255 41.805950
429 42.111828 41.088322
430 42.068062 41.191013
431 42.147125 41.601315
432 42.111267 41.370411
433 42.095078 41.381992
434 42.067917 41.487892
435 42.128036 41.622784
436 42.246876 41.394268
437 42.245373 41.395126
438 42.122765 41.258671
439 42.125824 41.771034
440 42.176102 41.325230
441 42.022976 40.988430
442 42.081821 41.394405
443 42.161446 41.286739
444 42.173492 41.373737
445 42.021744 41.160484
446 42.121262 41.319965
447 42.031612 41.067039
448 42.013744 41.168812
449 41.943787 41.299942
450 42.122696 41.293865
451 42.467323 41.651047
452 42.164436 41.164673
453 42.106731 41.083660
454 42.059650 40.991531
455 42.176949 41.426384
456 42.081562 41.259350
457 42.024689 41.449917
458 42.030643 41.020264
459 42.049995 41.241283
460 42.142960 41.205929
461 41.910553 41.114620
462 41.992405 41.124996
463 41.894688 40.983395
464 42.061298 41.389393
465 42.227573 41.479839
466 42.175461 41.941811
467 41.915447 40.878674
468 41.974064 41.639580
469 42.071255 41.349049
470 42.170204 41.482292
471 42.133923 41.060043
472 42.177132 41.067936
473 42.152084 41.551014
474 42.111397 41.212067
475 41.949821 41.019821
476 41.921036 41.098701
477 41.805679 41.107639
478 41.991096 41.479622
479 41.873230 41.066296
480 41.831280 41.123425
481 41.962620 41.033226
482 41.959923 41.240871
483 41.988945 41.180145
484 41.910683 41.048920
485 41.930115 41.171219
486 42.155830 41.618214
487 42.090412 41.340714
488 41.926365 41.035114
489 42.279736 41.808609
490 42.080589 41.016533
491 42.010029 41.527462
492 42.094784 41.202850
493 42.103680 41.707260
494 42.018597 40.870651
495 41.992836 41.589737
496 41.990936 41.330601
497 41.907230 41.012062
498 41.857666 41.258148
499 41.950981 41.270409
500 41.977772 41.467751
501 41.906109 41.173149
502 41.887047 41.408421
503 41.686737 40.788727
504 41.862286 41.406509
505 41.980446 41.493587
506 41.827488 40.919365
507 41.908943 41.143677
508 41.886524 41.042088
509 41.731739 40.883080
510 41.879910 40.888866
511 41.813515 41.156849
512 41.857796 41.087631
513 41.877563 41.263420
514 41.826733 40.969051
515 41.935905 41.169125
516 41.787067 41.119232
517 41.815952 41.275387
518 41.656502 40.708199
519 41.705406 40.862274
520 41.847084 41.199581
521 41.836750 40.818424
522 41.875713 41.300861
523 41.773254 40.775017
524 41.708744 40.929337
525 41.753094 40.971260
526 41.809929 41.016682
527 41.731800 40.812443
528 41.876511 40.776569
529 41.832520 41.218956
530 41.798237 41.098492
531 41.869129 40.866913
532 41.670387 40.917320
533 41.558941 40.732033
534 41.584198 40.888977
535 41.646076 40.699394
536 41.825211 41.314728
537 41.877888 41.209019
538 41.874569 40.779682
539 41.790394 41.036293
540 41.678162 40.679352
541 41.827297 41.073544
542 41.860817 41.381279
543 41.730244 40.923752
544 41.715946 40.956726
545 41.588318 40.659569
546 41.657391 41.284939
547 41.635639 40.834858
548 41.777176 40.736141
549 41.627342 40.674606
550 41.631546 40.893482
551 41.496197 40.710480
552 41.565506 41.486031
553 41.642769 41.107445
554 41.800079 41.078568
555 41.532654 40.545906
556 41.624531 41.028530
557 41.684696 40.934799
558 41.612522 40.689919
559 41.571518 40.780586
560 41.539387 40.757141
561 41.518322 40.659340
562 41.569038 40.936302
563 41.486897 40.640961
564 41.430187 40.707344
565 41.378338 40.546589
566 41.599056 41.007545
567 41.553341 40.736259
568 41.564583 41.023861
569 41.662045 40.648735
570 41.568188 40.632805
571 41.521824 40.756264
572 41.604736 40.992992
573 41.604275 40.914330
574 41.381992 40.543228
575 41.491383 41.032055
576 41.395073 40.678730
577 41.531811 40.694538
578 41.648514 41.222237
579 41.531082 40.667088
580 41.437744 40.539490
581 41.495621 40.520508
582 41.512562 40.566673
583 41.542316 40.656567
584 41.661018 41.384239
585 41.531498 40.707802
586 41.604176 40.570213
587 41.402355 40.556934
588 41.500256 40.995655
589 41.409679 40.516075
590 41.418301 40.655670
591 41.493027 40.445225
592 41.672611 41.460804
593 41.607914 40.939213
594 41.697895 41.253914
595 41.482330 40.809391
596 41.303638 40.486267
597 41.318302 40.648708
598 41.429344 40.570374
599 41.366116 40.649960
600 41.401955 40.708874
601 41.520546 40.689365
602 41.548271 41.571499
603 41.515591 40.744755
604 41.444080 40.479248
605 41.337673 40.466507
606 41.365715 40.729218
607 41.626705 41.067101
608 41.394535 40.419418
609 41.345085 40.508076
610 41.424660 40.940689
611 41.424084 40.622986
612 41.433434 40.528633
613 41.535995 40.617268
614 41.560452 40.488663
615 41.459953 40.465042
616 41.549667 40.667744
617 41.405312 40.801685
618 41.304848 40.441494
619 41.367802 40.485817
620 41.491177 40.628551
621 41.549614 40.734730
622 41.433884 40.462082
623 41.593071 41.021183
624 41.686382 40.952255
625 41.644787 40.702011
626 41.346592 40.492943
627 41.300320 40.398438
628 41.330936 40.805431
629 41.313736 40.456326
630 41.471794 40.817657
631 41.313923 40.473148
632 41.273590 40.498230
633 41.191666 40.406803
634 41.139400 40.737961
635 41.333096 40.849392
636 41.291962 40.756123
637 41.275272 40.473339
638 41.309368 40.837959
639 41.361656 40.640419
640 41.385620 40.486992
641 41.507359 41.046646
642 41.451199 40.665524
643 41.437210 40.470062
644 41.351738 40.841770
645 41.327591 40.784107
646 41.319489 40.499355
647 41.370430 40.459496
648 41.218079 40.390503
649 41.210152 40.357952
650 41.219578 40.524673
651 41.202190 40.282078
652 41.288521 40.614742
653 41.077396 40.189632
654 41.194698 40.652267
655 41.119354 40.231861
656 41.231789 40.615288
657 41.355175 40.359596
658 41.276402 40.235004
659 41.639481 41.093361
660 41.460621 40.685638
661 41.344440 40.374729
662 41.195580 40.403080
663 41.152054 40.261646
664 41.048531 40.121506
665 41.154869 40.312317
666 41.304161 40.715458
667 41.166290 40.214252
668 41.284626 40.363964
669 41.340397 40.522205
670 41.290802 40.421261
671 41.378082 40.558651
672 41.321747 40.596020
673 41.366146 40.747093
674 41.243782 40.486820
675 41.185677 40.374935
676 41.140579 40.521584
677 41.249638 40.657227
678 41.258457 40.267982
679 41.360931 40.932064
680 41.170448 40.265610
681 41.058727 40.357063
682 41.281151 41.072071
683 41.151405 40.766479
684 41.176552 40.171219
685 41.214455 40.509598
686 41.161293 40.298279
687 41.183250 40.214439
688 41.060959 40.520542
689 40.984051 40.378044
690 40.988049 40.232063
691 41.032921 40.665844
692 41.157627 40.345905
693 41.197033 40.318069
694 41.074337 40.127693
695 41.252502 40.383270
696 41.067722 40.366169
697 41.070461 40.564323
698 41.082527 40.214256
699 41.159801 40.227875
700 41.136448 40.270336
701 41.096451 40.226452
702 41.209103 40.380409
703 41.131569 40.243477
704 41.003620 40.051167
705 41.166714 40.572136
706 41.033661 40.303261
707 40.975845 39.971012
708 40.981819 40.063187
709 40.925880 40.112511
710 41.067966 40.683857
711 41.086628 40.367264
712 40.912136 40.082050
713 41.063393 40.407467
714 41.054035 40.265480
715 40.924576 40.178741
716 40.946564 40.182468
717 41.023506 40.353886
718 41.071487 40.420639
719 41.115040 40.309917
720 41.014431 40.474224
721 40.942276 40.043350
722 40.886715 40.085697
723 40.936646 40.060089
724 40.975372 40.077240
725 40.967655 40.656929
726 41.018806 40.343430
727 40.996471 40.275791
728 40.933498 40.224888
729 41.185730 41.394409
730 41.017967 39.964413
731 40.900627 40.621731
732 41.085785 40.420341
733 41.220982 40.448334
734 41.230083 40.103645
735 41.085861 40.142178
736 40.889782 40.201962
737 40.834663 40.403458
738 40.891838 40.693310
739 40.997715 40.297123
740 40.916607 40.238029
741 41.089497 40.268532
742 40.975601 40.186878
743 40.971169 40.369740
744 40.998337 39.976585
745 41.029079 40.195789
746 41.036694 40.374249
747 40.853886 39.979923
748 41.036224 41.102417
749 40.761005 40.015652
750 40.937458 40.373241
751 41.052063 40.220867
752 41.040680 40.737900
753 41.008465 40.819061
754 40.963085 40.209064
755 40.906162 40.123703
756 40.956936 40.376244
757 40.833447 40.136555
758 40.922260 40.071774
759 40.789394 40.114037
760 40.720764 40.011383
761 40.814255 40.178722
762 40.811016 40.074051
763 40.933746 40.486237
764 40.831131 40.016827
765 41.042549 41.079002
766 40.843761 39.904072
767 40.995380 40.416908
768 40.804173 40.005966
769 40.880569 40.051498
770 40.912930 40.050156
771 40.997143 40.335690
772 40.905266 40.051048
773 41.004200 40.618263
774 40.978661 40.312965
775 40.911121 40.455830
776 40.934414 39.908123
777 40.916431 40.018711
778 40.868862 40.035873
779 40.885670 40.418522
780 40.816441 39.980751
781 40.831200 40.423901
782 40.853802 40.103184
783 40.887348 40.341496
784 40.981075 40.406944
785 41.011086 40.304764
786 41.017017 40.588020
787 40.916386 40.085567
788 40.961414 40.277885
789 40.971256 40.065929
790 40.792004 39.989799
791 40.959461 40.285873
792 40.913250 40.565445
793 40.856689 40.040825
794 40.883919 40.298317
795 40.847378 39.931561
796 40.807690 39.921162
797 40.957062 40.206436
798 40.914886 40.175301
799 40.853085 39.775169
800 40.940441 40.161663
801 41.023335 40.231773
802 40.931759 40.045860
803 40.911732 39.979954
804 40.783710 39.930870
805 40.758572 39.909874
806 40.792576 40.040386
807 40.732674 39.860939
808 40.754860 39.982777
809 40.749760 39.906776
810 40.921444 40.970131
811 40.995865 40.865070
812 40.854610 40.027294
813 40.912228 40.486668
814 41.020950 40.134705
815 41.020638 40.232574
816 40.979534 40.663425
817 40.939857 40.470528
818 40.854759 40.057556
819 40.962372 40.234436
820 40.886402 40.519402
821 40.858433 40.075470
822 40.709923 39.929611
823 40.747318 40.040264
824 40.854713 40.244896
825 40.821587 39.987995
826 40.797852 39.919605
827 40.830460 40.380699
828 40.786392 39.852871
829 41.035145 40.107105
830 40.860668 39.845684
831 40.720524 39.894104
832 40.607357 39.679382
833 40.684887 39.846333
834 40.778606 40.394276
835 40.765900 39.939323
836 40.789360 40.130211
837 40.853348 40.483345
838 40.796124 39.814774
839 41.033669 41.100956
840 40.895020 40.389412
841 40.917355 40.565418
842 40.889256 40.017387
843 40.891983 40.241737
844 40.805092 40.372475
845 40.851681 40.248898
846 40.896931 40.892208
847 40.673424 39.977539
848 40.707779 40.055683
849 40.811199 40.334545
850 40.761631 39.748600
851 40.663139 39.922260
852 40.643574 39.794044
853 40.664204 40.014778
854 40.592197 39.708874
855 40.816669 40.455219
856 40.720055 39.844826
857 40.718552 39.940586
858 40.613964 39.866539
859 40.646267 39.843620
860 40.720890 40.224220
861 40.651451 39.945259
862 40.639641 39.675095
863 40.624245 39.782650
864 40.543926 39.642170
865 40.714779 40.279339
866 40.869873 40.452320
867 40.779217 39.903790
868 40.693844 40.072952
869 40.826004 40.404305
870 40.914467 40.210590
871 40.910683 40.033123
872 40.772545 39.731583
873 40.677059 39.944336
874 40.656075 39.843483
875 40.701370 39.970848
876 40.755684 40.430878
877 40.647034 39.719284
878 40.667519 40.173336
879 40.951550 40.459000
880 40.839649 39.846504
881 40.720222 39.831676
882 40.675961 39.890724
883 40.711239 39.744167
884 40.802536 40.458336
885 40.806564 40.309338
886 40.696140 39.750095
887 40.701099 39.775692
888 40.797997 40.449238
889 40.709328 39.882046
890 40.628433 40.026527
891 40.640503 39.730396
892 40.561245 39.975475
893 40.685043 40.187569
894 40.755295 39.989983
895 40.766693 40.635258
896 40.877121 39.987415
897 40.668056 39.940529
898 40.667946 39.888474
899 40.719513 39.822861
900 40.809032 39.923912
901 40.733295 40.116039
902 40.621716 39.949795
903 40.505562 39.725430
904 40.517445 39.866833
905 40.645538 39.874489
906 40.560749 39.623547
907 40.487881 40.043427
908 40.598515 40.013248
909 40.674671 39.987736
910 40.744576 39.776794
911 40.644955 39.774708
912 40.619812 40.406536
913 40.616779 39.846989
914 40.616253 39.874649
915 40.611301 40.175484
916 40.653305 40.090237
917 40.657349 39.597282
918 40.546528 39.802139
919 40.499569 39.872402
920 40.575958 39.874870
921 40.696377 40.103294
922 40.649117 39.918270
923 40.522972 39.600502
924 40.577835 40.112892
925 40.633358 39.994263
926 40.644714 40.010845
927 40.735760 40.165901
928 40.720463 40.255413
929 40.792141 40.219570
930 40.683563 40.170044
931 40.598984 39.904247
932 40.555508 39.931896
933 40.514114 39.786411
934 40.450920 39.605419
935 40.534355 39.626324
936 40.578091 39.945599
937 40.661781 39.832367
938 40.558086 40.150303
939 40.729145 40.869461
940 40.639793 39.808434
941 40.489254 39.636040
942 40.589981 40.369347
943 40.559216 39.810024
944 40.680058 39.892796
945 40.665073 39.821957
946 40.636124 39.763386
947 40.626278 39.673855
948 40.596588 40.002537
949 40.498104 39.673767
950 40.485558 40.115692
951 40.430492 39.663311
952 40.580757 39.736725
953 40.636387 40.064297
954 40.568398 39.910618
955 40.525951 39.909176
956 40.726250 39.986588
957 40.698860 40.060364
958 40.779949 40.155952
959 40.774429 39.980026
960 40.569839 39.726318
961 40.557690 39.821682
962 40.569077 39.892704
963 40.497444 39.667789
964 40.440483 39.623306
965 40.486912 39.751240
966 40.524788 39.752502
967 40.465141 39.677811
968 40.475712 39.677692
969 40.541534 39.843723
970 40.567284 39.868816
971 40.564899 39.756290
972 40.536880 39.690598
973 40.383530 39.671398
974 40.388718 39.674389
975 40.453445 39.920189
976 40.500866 39.818985
977 40.500160 40.493385
978 40.463108 39.753830
979 40.390816 39.838242
980 40.498051 40.046543
981 40.560879 40.093803
982 40.514053 40.029129
983 40.593937 40.488987
984 40.501015 39.599293
985 40.445526 39.805458
986 40.511055 39.800167
987 40.518280 39.735714
988 40.412918 39.525146
989 40.602623 40.099434
990 40.613029 40.253044
991 40.564087 39.914024
992 40.535580 39.601521
993 40.572948 40.043877
994 40.684002 39.950649
995 40.539352 39.803951
996 40.483269 39.676365
997 40.624840 39.573715
998 40.630455 40.416370
999 40.677254 40.898376
1000 40.486931 39.687836
1001 40.527069 39.633186
1002 40.474026 39.606518
1003 40.500767 39.563538
1004 40.485863 40.457455
1005 40.459488 39.804401
1006 40.507988 39.729321
1007 40.538780 40.188988
1008 40.561661 39.704449
1009 40.533539 40.208008
1010 40.476673 39.585136
1011 40.497005 40.391151
1012 40.471378 39.857773
1013 40.729794 40.056480
1014 40.483036 39.558571
1015 40.507442 39.715366
1016 40.643284 40.431927
1017 40.542389 40.362988
1018 40.531849 40.393620
1019 40.623852 40.329178
1020 40.538666 39.800465
1021 40.497059 39.715302
1022 40.617298 39.965569
1023 40.517368 39.815086
1024 40.464912 39.873730
1025 40.424934 39.619011
1026 40.442238 40.082062
1027 40.506176 39.910339
1028 40.538811 39.917976
1029 40.452213 39.594452
1030 40.449898 39.832100
1031 40.538639 39.809990
1032 40.468521 39.660194
1033 40.581989 40.132023
1034 40.480812 39.645756
1035 40.521618 39.545223
1036 40.509872 39.884480
1037 40.571037 39.678276
1038 40.575821 40.458538
1039 40.587395 39.857899
1040 40.403427 39.600536
1041 40.424961 40.102814
1042 40.327824 39.776329
1043 40.405025 39.663013
1044 40.545185 39.715286
1045 40.433685 40.063530
1046 40.493069 39.633972
1047 40.522240 39.692543
1048 40.499191 39.971050
1049 40.565495 40.043434
1050 40.490971 39.503593
1051 40.394646 39.452713
1052 40.505394 39.930416
1053 40.457630 39.584251
1054 40.524063 40.906895
1055 40.381950 39.533649
1056 40.447968 39.767536
1057 40.485237 40.004574
1058 40.591839 40.087379
1059 40.569870 40.261402
1060 40.489723 39.738266
1061 40.394268 39.961235
1062 40.308933 39.570168
1063 40.439533 40.153099
1064 40.386005 40.076324
1065 40.464310 39.864174
1066 40.455135 39.433292
1067 40.483479 39.604038
1068 40.376648 39.542095
1069 40.354000 39.770676
1070 40.407146 39.511738
1071 40.490440 39.923729
1072 40.420464 39.641220
1073 40.410698 39.760014
1074 40.379402 39.490726
1075 40.686226 40.676811
1076 40.502636 39.837307
1077 40.486832 39.803295
1078 40.520935 39.821896
1079 40.576923 39.577606
1080 40.440536 40.169582
1081 40.517105 40.211643
1082 40.534313 39.955830
1083 40.552937 39.955944
1084 40.496479 39.687641
1085 40.425030 39.579224
1086 40.378483 39.725609
1087 40.428391 39.508148
1088 40.394848 39.820133
1089 40.409531 39.744545
1090 40.401943 39.766262
1091 40.380562 39.655956
1092 40.461456 39.853828
1093 40.645985 40.137665
1094 40.498913 40.011494
1095 40.472145 40.146420
1096 40.364918 39.754036
1097 40.545738 40.523785
1098 40.641361 40.056324
1099 40.392792 39.549179
1100 40.432205 39.489304
1101 40.433895 39.814709
1102 40.520691 39.927532
1103 40.701027 40.244080
1104 40.543392 39.894527
1105 40.545578 39.996243
1106 40.487587 39.714054
1107 40.467247 40.363487
1108 40.452919 40.516113
1109 40.510250 39.931614
1110 40.523323 39.992432
1111 40.542763 40.571030
1112 40.423470 40.211361
1113 40.438393 39.843109
1114 40.543892 39.796761
1115 40.478603 39.722286
1116 40.475811 40.108189
1117 40.583199 40.107273
1118 40.468925 40.136986
1119 40.609550 40.372501
1120 40.371319 39.565388
1121 40.480164 40.028801
1122 40.382301 39.773926
1123 40.428699 39.981255
1124 40.381283 39.535843
1125 40.392227 39.840981
1126 40.424850 40.076019
1127 40.432117 39.641319
1128 40.479900 39.794704
1129 40.511364 39.746494
1130 40.358841 39.422947
1131 40.337990 39.645966
1132 40.301529 39.420403
1133 40.338196 39.617676
1134 40.460522 40.035027
1135 40.485474 39.707615
1136 40.485203 39.854336
1137 40.416042 39.803383
1138 40.462479 39.785225
1139 40.402809 39.971245
1140 40.505932 40.322811
1141 40.507717 39.655350
1142 40.411400 39.754681
1143 40.284733 39.588455
1144 40.342941 39.694672
1145 40.372261 39.576389
1146 40.576756 39.666443
1147 40.579685 39.904652
1148 40.359253 39.608665
1149 40.371147 39.413570
1150 40.436649 39.536232
1151 40.463135 39.963833
1152 40.489616 40.031631
1153 40.432259 39.951969
1154 40.511436 39.979687
1155 40.430931 39.630451
1156 40.452488 39.747173
1157 40.562775 40.380337
1158 40.472847 39.765800
1159 40.488167 39.697468
1160 40.494900 40.141106
1161 40.432598 39.739311
1162 40.451591 39.726501
1163 40.405762 39.619675
1164 40.384483 39.770905
1165 40.408836 39.742611
1166 40.442734 39.992306
1167 40.394329 39.610775
1168 40.493603 39.514900
1169 40.331509 39.693722
1170 40.322540 39.744461
1171 40.407345 40.125820
1172 40.417137 40.112587
1173 40.344223 39.536308
1174 40.510098 39.694836
1175 40.384354 39.511013
1176 40.372833 39.836594
1177 40.311241 39.634583
1178 40.382477 39.745964
1179 40.412781 39.847099
1180 40.389812 39.597218
1181 40.576595 39.988041
1182 40.496552 39.812248
1183 40.457848 39.580673
1184 40.393391 39.518860
1185 40.433102 39.617455
1186 40.376633 39.579475
1187 40.583210 40.557339
1188 40.498699 39.614052
1189 40.394936 39.589581
1190 40.434704 39.580017
1191 40.347633 39.601406
1192 40.276176 39.741463
1193 40.377785 40.142254
1194 40.369446 39.658661
1195 40.351711 39.570824
1196 40.386543 39.613274
1197 40.523743 40.351780
1198 40.451385 39.876171
1199 40.337227 39.680267
1200 40.396854 39.841129


In [139]:
# Plot the losses the recorder collected during training.
l.recorder.plot_losses()



In [135]:
# Save the learner under a name derived from the run configuration
# (optimiser, batch-norm flag, number of hidden layers).
# NOTE(review): the name contains "mse" although the learner is built with the
# percentage-error `criterion` (the MSELoss line is commented out) — confirm
# the label. The load cell must use exactly the same name.
l.save(f"speedup_{optimizer}_batch_norm_{batch_norm}_mse_nlayers_{len(layers_sizes)}")

In [6]:
# Restore the checkpoint written by the `l.save(...)` cell; the name is rebuilt
# from the same configuration values, so both cells must stay in sync.
l = l.load(f"speedup_{optimizer}_batch_norm_{batch_norm}_mse_nlayers_{len(layers_sizes)}")

In [7]:
def _flat_preds(learner, ds_type):
    """Return (predictions, targets) for one split as flat NumPy arrays."""
    split_preds, split_targets = learner.get_preds(ds_type)
    return split_preds.reshape((-1,)).numpy(), split_targets.reshape((-1,)).numpy()


def _error_frame(preds, targets):
    """Build a per-sample error table from flat prediction/target arrays.

    Args:
        preds: 1-D array of model predictions.
        targets: 1-D array of ground-truth values, same length as ``preds``.

    Returns:
        DataFrame with columns ``pred``, ``target``, ``abs_diff`` (absolute
        error) and ``APE`` (absolute error as a percentage of the target).
    """
    abs_diff = np.abs(preds - targets)
    columns = ['pred', 'target', 'abs_diff', 'APE']
    return pd.DataFrame(
        {
            'pred': preds,
            'target': targets,
            'abs_diff': abs_diff,
            # |target - pred| equals |pred - target|, so the absolute error is reused.
            'APE': abs_diff / targets * 100,
        },
        columns=columns,
    )


# Validation split first, then training split. `preds` / `targets` are left
# holding the flattened training-split arrays, as before.
preds, targets = _flat_preds(l, fai.basic_data.DatasetType.Valid)
val_df = _error_frame(preds, targets)

preds, targets = _flat_preds(l, fai.basic_data.DatasetType.Train)
train_df = _error_frame(preds, targets)

In [105]:
# Summary statistics of predictions, targets and errors on the training split.
train_df.describe()


Out[105]:
pred target abs_diff APE
count 90000.000000 90000.000000 90000.000000 90000.000000
mean 1.941154 2.031324 0.684319 61.384346
std 1.609805 1.925734 0.753625 107.120293
min -0.157797 0.028617 0.000011 0.000544
25% 0.742593 0.649040 0.180828 14.958478
50% 1.470519 1.396559 0.429848 32.670835
75% 2.753879 2.621620 0.925115 61.996639
max 10.633808 13.560771 8.974062 2388.007812

In [100]:
# Summary statistics of predictions, targets and errors on the training split.
# NOTE(review): this repeats the previous cell, yet the two recorded outputs
# differ and the execution counts are out of order — they appear to come from
# different runs; keep only the one matching the current checkpoint.
train_df.describe()


Out[100]:
pred target abs_diff APE
count 90000.000000 90000.000000 90000.000000 90000.000000
mean 1.955619 2.031322 0.686646 65.346382
std 1.562531 1.925733 0.743749 118.350204
min -0.145053 0.028617 0.000002 0.000185
25% 0.790863 0.649040 0.187947 15.021752
50% 1.513035 1.396559 0.444722 32.155708
75% 2.789342 2.621620 0.923823 63.015770
max 10.399451 13.560771 8.745758 2802.083008

In [134]:
# Summary statistics of predictions, targets and errors on the validation split.
val_df.describe()


Out[134]:
pred target abs_diff APE
count 10000.000000 10000.000000 10000.000000 10000.000000
mean 1.502686 2.019351 0.763961 42.495895
std 1.421925 1.923597 0.926997 36.887840
min -0.139572 0.031008 0.000083 0.007530
25% 0.433389 0.645243 0.150516 17.843959
50% 0.992827 1.387672 0.426106 35.945587
75% 2.151397 2.617450 1.029043 60.084542
max 8.528646 13.558847 7.275390 654.606445

In [101]:
# Summary statistics of predictions, targets and errors on the validation split.
# NOTE(review): this repeats the previous cell, yet the two recorded outputs
# differ and the execution counts are out of order — they appear to come from
# different runs; keep only the one matching the current checkpoint.
val_df.describe()


Out[101]:
pred target abs_diff APE
count 10000.000000 10000.000000 10000.000000 10000.000000
mean 1.959154 2.019354 0.676217 65.818283
std 1.570999 1.923599 0.731138 122.390106
min -0.089540 0.031008 0.000007 0.000882
25% 0.777360 0.645243 0.177962 14.538821
50% 1.504867 1.387672 0.436823 31.890802
75% 2.792107 2.617450 0.914852 64.005240
max 10.399451 13.558847 6.138021 2274.407227

In [ ]:
# Joint distribution of ground-truth target vs. model prediction on the
# validation split. The call needs data to plot: with no arguments it either
# raises (older seaborn) or draws an empty figure. The trailing `;` suppresses
# the returned grid object's repr.
sns.jointplot(x='target', y='pred', data=val_df);