In [37]:
# Notebook setup: autoreload picks up edits to the local modules
# (data_loader, model, model_bn, main) without restarting the kernel.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
# NOTE(review): 40x30-inch default figures are very large; consider a smaller
# default and sizing individual plots explicitly.
plt.rcParams['figure.figsize'] = [40, 30]
In [1]:
# NOTE: the star imports below can rebind names imported earlier (e.g. a module
# re-exporting DataLoader), so their relative order is kept as-is on purpose.
import fastai as fai
from fastai.basic_data import DataLoader
from data_loader import *
import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler
from model import *
from model_bn import *
from torch import optim
import dill
import papermill as pm
from main import *
import seaborn as sns
import pandas as pd
In [2]:
# Experiment parameters, kept as plain assignments.
# Presumably this is the papermill parameters cell (papermill is imported) —
# TODO confirm the cell carries the "parameters" tag.
optimizer = 'Adam'      # 'Adam' keeps fastai's default optimizer; 'SGD' selects torch.optim.SGD
num_workers = 8         # worker processes passed to train_dev_split
maxsize = 100000        # passed to train_dev_split; train (90000) + val (10000) rows below sum to this
batch_size = 2048
n_epochs = 1200         # epoch budget for fit_one_cycle (the recorded run used 1200)
batch_norm = True       # True -> Model_BN with hidden layer sizes, False -> plain Model
dataset = 'data/speedup_dataset.h5'
In [3]:
# Split the HDF5 dataset into training / validation loaders and bundle them
# into a fastai DataBunch for the Learner.
loaders = train_dev_split(dataset, batch_size, num_workers, maxsize)
train_dl, val_dl = loaders
db = fai.basic_data.DataBunch(train_dl, val_dl)
In [4]:
def criterion(inputs, targets, eps=1e-5):
    """Mean absolute percentage error (MAPE), in percent.

    Computes ``mean(|targets - inputs| / (targets + eps) * 100)``.

    Args:
        inputs: predicted values (tensor).
        targets: ground-truth values. The denominator is not wrapped in
            ``abs``, so targets are assumed to be strictly positive (speedups);
            negative targets would give negative per-sample terms.
            ``inputs`` and ``targets`` should have the same shape. Mismatched
            shapes such as (B, 1) vs (B,) broadcast to (B, B) silently.
        eps: constant added to the denominator to avoid division by zero.
            Defaults to 1e-5, the value previously hard-coded.

    Returns:
        A 0-dim tensor holding the mean percentage error.
    """
    relative_error = torch.abs(targets - inputs) / (targets + eps)
    return torch.mean(relative_error * 100)
In [5]:
# Build the network and wrap it in a fastai Learner trained with `criterion`.
input_size = train_dl.dataset.X.shape[1]
output_size = train_dl.dataset.Y.shape[1]
layers_sizes = [300, 200, 120, 80, 30]

if batch_norm:
    model = Model_BN(input_size, output_size, hidden_sizes=layers_sizes)
else:
    # NOTE(review): the plain Model is not given `layers_sizes`, so the two
    # branches may not share an architecture — confirm Model's signature.
    model = Model(input_size, output_size)

# Select the optimizer before constructing the Learner instead of patching
# `opt_func` afterwards. Unknown names now fail loudly rather than silently
# training with the default optimizer.
if optimizer == 'SGD':
    learner_kwargs = {'opt_func': optim.SGD}
elif optimizer == 'Adam':
    learner_kwargs = {}  # keep fastai's default (Adam-based) optimizer
else:
    raise ValueError(f"Unsupported optimizer: {optimizer!r} (expected 'Adam' or 'SGD')")

l = fai.Learner(db, model, loss_func=criterion, **learner_kwargs)
In [44]:
# Display the network architecture (rich display instead of print()).
l.model
Model_BN(
(hidden_layers): ModuleList(
(0): Linear(in_features=381, out_features=300, bias=False)
(1): Linear(in_features=300, out_features=200, bias=False)
(2): Linear(in_features=200, out_features=120, bias=False)
(3): Linear(in_features=120, out_features=80, bias=False)
(4): Linear(in_features=80, out_features=30, bias=False)
)
(batch_norm_layers): ModuleList(
(0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(predict): Linear(in_features=30, out_features=1, bias=True)
)
In [116]:
# Learning-rate range test; its result is plotted by the next cell.
l.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [117]:
# Plot the loss recorded by lr_find against learning rate, used to pick `lr`.
l.recorder.plot()
In [136]:
lr = 1e-05  # peak learning rate for fit_one_cycle (presumably read off the lr_find plot)
In [137]:
# One-cycle training with peak learning rate `lr`. The epoch budget comes from
# the `n_epochs` parameter instead of a hard-coded literal, so the parameter
# cell actually controls training length.
# NOTE: the log below was produced with 1200 epochs; set n_epochs to 1200 to reproduce it.
l.fit_one_cycle(n_epochs, lr)
Total time: 2:32:36
epoch
train_loss
valid_loss
1
43.308823
42.511745
2
43.440548
42.512672
3
43.522625
42.832088
4
43.359398
42.508614
5
43.424362
42.823555
6
43.302296
42.570656
7
43.337223
42.466236
8
43.198784
42.481415
9
43.276703
42.472767
10
43.393902
42.904263
11
43.391769
42.667290
12
43.403667
42.497829
13
43.380348
42.411304
14
43.314499
42.548527
15
43.428509
42.706577
16
43.367168
42.650780
17
43.343029
42.677246
18
43.319000
42.489429
19
43.391106
42.656219
20
43.388210
42.591579
21
43.298534
42.557522
22
43.390308
42.564564
23
43.285198
42.421730
24
43.337765
42.498516
25
43.398766
42.607731
26
43.432446
42.581360
27
43.537910
42.634064
28
43.363823
42.551136
29
43.272110
42.537769
30
43.310291
42.701267
31
43.352707
42.672440
32
43.470226
42.658062
33
43.471638
43.215084
34
43.381989
42.497452
35
43.275143
42.465843
36
43.454556
42.537891
37
43.342133
42.553917
38
43.491207
42.599930
39
43.492729
42.788162
40
43.393791
42.519951
41
43.326050
42.508144
42
43.264702
42.469883
43
43.262543
42.562611
44
43.189724
42.455605
45
43.508636
43.351971
46
43.469368
42.524853
47
43.415024
42.548290
48
43.464169
42.469677
49
43.532379
42.680756
50
43.446045
42.569805
51
43.244213
42.418137
52
43.348724
42.588364
53
43.323513
42.531250
54
43.429993
42.579483
55
43.407803
42.584702
56
43.445045
42.670910
57
43.460434
42.590237
58
43.447296
42.722595
59
43.325001
42.600826
60
43.355171
42.485855
61
43.348545
42.517017
62
43.279091
42.603905
63
43.369083
42.579792
64
43.432575
42.719604
65
43.240456
42.461514
66
43.270180
42.428192
67
43.428951
42.719746
68
43.407288
42.508919
69
43.487976
42.704956
70
43.427967
42.735149
71
43.384842
42.591808
72
43.280643
42.450020
73
43.363865
42.864948
74
43.431782
42.713253
75
43.452579
42.648865
76
43.370678
42.467739
77
43.367920
42.591820
78
43.376579
42.467312
79
43.279800
42.487236
80
43.280773
42.534298
81
43.337933
42.719807
82
43.303379
42.480621
83
43.540298
42.723038
84
43.448555
42.610886
85
43.321148
42.382626
86
43.288475
42.521107
87
43.310978
42.427227
88
43.354904
42.763062
89
43.357090
42.701160
90
43.235104
42.497105
91
43.286869
42.386902
92
43.400986
42.518703
93
43.344879
42.422607
94
43.313339
42.456898
95
43.344193
42.705093
96
43.289135
42.512676
97
43.295864
42.409733
98
43.399303
42.475636
99
43.362911
42.493778
100
43.316734
42.766773
101
43.316734
42.588840
102
43.397621
42.586967
103
43.487553
42.623798
104
43.426090
42.680416
105
43.335175
42.450436
106
43.267071
42.332695
107
43.291157
42.436569
108
43.248577
42.464687
109
43.265816
42.339283
110
43.259209
42.500687
111
43.165733
42.683731
112
43.152603
42.397896
113
43.320751
42.639519
114
43.263523
42.415813
115
43.270466
42.813183
116
43.334484
42.662815
117
43.389759
42.488380
118
43.325508
42.603394
119
43.361469
42.493492
120
43.335346
42.578968
121
43.294811
42.548058
122
43.222107
42.366676
123
43.160793
42.383427
124
43.141220
42.593845
125
43.309528
42.506844
126
43.144165
42.516418
127
43.284153
42.507442
128
43.323742
42.441414
129
43.351765
42.560127
130
43.193291
42.463081
131
43.429253
42.546333
132
43.271992
42.470543
133
43.363823
42.749699
134
43.237473
42.513920
135
43.264339
42.445187
136
43.331261
42.561455
137
43.312698
42.553104
138
43.244049
42.574196
139
43.318817
42.570320
140
43.194126
42.338051
141
43.211468
42.325531
142
43.285694
42.543560
143
43.373501
42.735111
144
43.335075
42.575508
145
43.193237
42.311756
146
43.247868
42.399548
147
43.201267
42.263844
148
43.297466
42.382526
149
43.224606
42.407101
150
43.180817
42.357018
151
43.379089
42.760448
152
43.277565
42.502728
153
43.373203
42.504719
154
43.370510
42.435719
155
43.421436
42.638718
156
43.295113
42.404251
157
43.295193
42.432026
158
43.393440
42.474426
159
43.315281
42.686474
160
43.289707
42.407482
161
43.250385
42.288689
162
43.222820
42.413395
163
43.211552
42.366455
164
43.263912
42.659931
165
43.315254
42.561066
166
43.311863
42.646599
167
43.198475
42.334595
168
43.190556
42.534267
169
43.329521
42.330570
170
43.321499
42.354347
171
43.227497
42.327354
172
43.335384
42.516888
173
43.269951
42.355476
174
43.178883
42.432110
175
43.089993
42.195763
176
43.221603
42.362789
177
43.184181
42.476276
178
43.176987
42.400883
179
43.170002
42.517239
180
43.141499
42.437321
181
43.233356
42.316410
182
43.172886
42.436558
183
43.143227
42.500553
184
43.285915
42.745098
185
43.280193
42.496872
186
43.258728
42.520077
187
43.211422
42.226952
188
43.184296
42.244671
189
43.349567
42.674175
190
43.151520
42.397049
191
43.199123
42.401440
192
43.136684
42.312008
193
43.166206
42.420994
194
43.263863
42.250496
195
43.133865
42.286366
196
43.206516
42.311127
197
43.265869
42.431080
198
43.268402
42.437626
199
43.119148
42.320129
200
43.143574
42.650581
201
43.257111
42.446098
202
43.112968
42.259274
203
43.098816
42.234215
204
43.317780
43.193264
205
43.265362
42.595005
206
43.120068
42.211823
207
43.050522
42.187096
208
43.137661
42.296898
209
43.038986
42.228149
210
43.204395
42.334774
211
43.156063
42.152214
212
43.167034
42.443050
213
43.014675
42.193607
214
43.183807
42.403728
215
43.211624
42.190807
216
43.130669
42.245991
217
43.168030
42.448418
218
43.108006
42.431931
219
43.096741
42.375481
220
43.227245
42.572723
221
43.045704
42.201012
222
43.024002
42.356026
223
43.080677
42.380455
224
42.980579
42.152283
225
43.053883
42.308189
226
42.940228
42.161774
227
43.005127
42.188911
228
43.033890
42.173935
229
42.992813
42.585255
230
42.999443
42.258640
231
42.961220
42.297909
232
42.917561
42.176254
233
43.008018
42.334019
234
43.051231
42.351738
235
43.018929
42.188313
236
42.975731
42.186150
237
42.963436
42.153095
238
43.129246
42.566322
239
43.091927
42.312054
240
43.121372
42.086952
241
43.029343
42.135124
242
43.044067
42.219898
243
43.012306
42.122932
244
43.014561
42.322819
245
43.053043
42.117138
246
43.017601
42.223198
247
43.045921
42.410545
248
43.019180
42.228680
249
43.002071
42.257088
250
43.103962
42.381157
251
42.967606
42.110027
252
42.945602
42.112461
253
42.846420
41.988468
254
42.916729
42.042744
255
43.028671
42.331482
256
42.947941
42.170731
257
42.880531
42.169357
258
42.901554
42.307484
259
43.014194
42.091679
260
42.970135
42.197788
261
43.011520
42.168026
262
43.005985
42.154202
263
43.004566
42.135094
264
42.884888
41.997787
265
42.831566
42.002308
266
42.959160
42.288376
267
42.924625
42.021477
268
43.017998
42.208313
269
43.120888
42.410500
270
43.162666
42.276325
271
43.073608
41.938839
272
42.901661
41.855385
273
42.842232
42.272804
274
42.882057
41.952637
275
42.916054
42.177704
276
42.792236
42.238007
277
42.890148
42.084408
278
42.828568
42.205986
279
42.811569
42.049580
280
43.243748
42.488880
281
43.065220
42.080807
282
42.888756
42.077328
283
42.845760
41.826656
284
42.935131
42.148705
285
42.760319
42.120308
286
42.879009
42.045528
287
42.862892
41.905067
288
42.982159
42.151749
289
43.023952
42.133980
290
43.071560
42.426815
291
42.938454
42.044838
292
42.861984
42.103294
293
43.003902
42.481201
294
42.831276
42.051811
295
42.980816
42.079124
296
42.809906
41.972889
297
42.792458
42.026894
298
42.810532
42.067600
299
42.822411
42.184246
300
42.779461
41.771545
301
42.880264
42.175171
302
42.844345
41.856579
303
42.655869
41.949787
304
42.701820
41.751923
305
42.678684
41.779453
306
42.744370
41.835495
307
42.813496
42.149643
308
42.879421
42.384701
309
43.057240
42.097874
310
42.972980
41.883121
311
42.877312
42.190792
312
43.047871
42.169724
313
42.931076
41.951782
314
42.758377
41.773201
315
42.701591
41.921288
316
42.953884
42.152615
317
42.756279
41.882862
318
42.677769
41.713383
319
42.578537
41.687145
320
42.628929
41.960384
321
42.529114
41.883400
322
42.572681
41.885738
323
42.579865
41.658867
324
42.808060
42.204353
325
42.771034
42.056339
326
42.570560
41.774673
327
42.529305
41.784626
328
42.615219
41.629383
329
42.660172
42.022442
330
42.852463
42.097675
331
42.759567
42.097233
332
42.576832
41.618286
333
42.746090
42.166943
334
42.922039
42.154999
335
43.016827
42.046646
336
42.746376
41.685814
337
42.636288
41.800156
338
42.586994
42.034973
339
42.475388
41.628414
340
42.426167
41.605614
341
42.472683
41.705425
342
42.565437
42.146664
343
42.488110
41.637955
344
42.781971
42.115387
345
42.607430
41.648045
346
42.512070
41.692696
347
42.520882
41.668137
348
42.613682
41.942402
349
42.654659
42.079990
350
42.608181
41.661129
351
42.572063
41.775986
352
42.537857
41.571064
353
42.544552
41.509159
354
42.469330
41.781708
355
42.551334
41.602943
356
42.440598
41.464443
357
42.496616
41.676586
358
42.609161
41.565777
359
42.489223
41.349514
360
42.752827
41.785976
361
42.677479
41.707218
362
42.543865
41.716900
363
42.322128
41.387867
364
42.452339
41.645294
365
42.441360
41.609970
366
42.490810
41.658108
367
42.725384
42.023575
368
42.652695
41.678944
369
42.660336
41.903721
370
42.630302
41.843071
371
42.478783
41.507824
372
42.419155
41.640976
373
42.410950
41.892738
374
42.327682
41.737839
375
42.599144
41.745640
376
42.454029
41.552967
377
42.417725
41.545242
378
42.400539
41.380894
379
42.446865
41.659908
380
42.384380
41.816437
381
42.403580
41.594154
382
42.443386
41.938171
383
42.359291
41.371670
384
42.537037
41.883488
385
42.516636
41.500137
386
42.471733
41.739201
387
42.418522
41.567677
388
42.439098
41.529274
389
42.334743
41.361462
390
42.333385
41.571407
391
42.385750
41.863811
392
42.593742
41.712864
393
42.454536
41.454147
394
42.563129
41.970089
395
42.420856
41.415218
396
42.879295
42.796795
397
42.606812
41.544258
398
42.390259
41.528698
399
42.222370
41.487305
400
42.263580
41.282112
401
42.267460
41.468239
402
42.326324
41.465820
403
42.563320
42.024467
404
42.418571
41.707989
405
42.355164
41.242661
406
42.408871
42.003685
407
42.207825
41.221474
408
42.235348
41.318092
409
42.317959
41.438202
410
42.319637
41.463772
411
42.231670
41.405922
412
42.289207
41.522751
413
42.288872
41.409077
414
42.088970
41.223618
415
42.208500
41.662086
416
42.193970
41.502048
417
42.024624
41.080818
418
42.127243
41.360786
419
42.162781
41.379070
420
42.239967
41.522907
421
42.289818
41.674026
422
42.249401
41.432499
423
42.369846
41.651722
424
42.364357
41.472591
425
42.301571
41.310585
426
42.153515
41.232483
427
42.105740
41.359154
428
42.209255
41.805950
429
42.111828
41.088322
430
42.068062
41.191013
431
42.147125
41.601315
432
42.111267
41.370411
433
42.095078
41.381992
434
42.067917
41.487892
435
42.128036
41.622784
436
42.246876
41.394268
437
42.245373
41.395126
438
42.122765
41.258671
439
42.125824
41.771034
440
42.176102
41.325230
441
42.022976
40.988430
442
42.081821
41.394405
443
42.161446
41.286739
444
42.173492
41.373737
445
42.021744
41.160484
446
42.121262
41.319965
447
42.031612
41.067039
448
42.013744
41.168812
449
41.943787
41.299942
450
42.122696
41.293865
451
42.467323
41.651047
452
42.164436
41.164673
453
42.106731
41.083660
454
42.059650
40.991531
455
42.176949
41.426384
456
42.081562
41.259350
457
42.024689
41.449917
458
42.030643
41.020264
459
42.049995
41.241283
460
42.142960
41.205929
461
41.910553
41.114620
462
41.992405
41.124996
463
41.894688
40.983395
464
42.061298
41.389393
465
42.227573
41.479839
466
42.175461
41.941811
467
41.915447
40.878674
468
41.974064
41.639580
469
42.071255
41.349049
470
42.170204
41.482292
471
42.133923
41.060043
472
42.177132
41.067936
473
42.152084
41.551014
474
42.111397
41.212067
475
41.949821
41.019821
476
41.921036
41.098701
477
41.805679
41.107639
478
41.991096
41.479622
479
41.873230
41.066296
480
41.831280
41.123425
481
41.962620
41.033226
482
41.959923
41.240871
483
41.988945
41.180145
484
41.910683
41.048920
485
41.930115
41.171219
486
42.155830
41.618214
487
42.090412
41.340714
488
41.926365
41.035114
489
42.279736
41.808609
490
42.080589
41.016533
491
42.010029
41.527462
492
42.094784
41.202850
493
42.103680
41.707260
494
42.018597
40.870651
495
41.992836
41.589737
496
41.990936
41.330601
497
41.907230
41.012062
498
41.857666
41.258148
499
41.950981
41.270409
500
41.977772
41.467751
501
41.906109
41.173149
502
41.887047
41.408421
503
41.686737
40.788727
504
41.862286
41.406509
505
41.980446
41.493587
506
41.827488
40.919365
507
41.908943
41.143677
508
41.886524
41.042088
509
41.731739
40.883080
510
41.879910
40.888866
511
41.813515
41.156849
512
41.857796
41.087631
513
41.877563
41.263420
514
41.826733
40.969051
515
41.935905
41.169125
516
41.787067
41.119232
517
41.815952
41.275387
518
41.656502
40.708199
519
41.705406
40.862274
520
41.847084
41.199581
521
41.836750
40.818424
522
41.875713
41.300861
523
41.773254
40.775017
524
41.708744
40.929337
525
41.753094
40.971260
526
41.809929
41.016682
527
41.731800
40.812443
528
41.876511
40.776569
529
41.832520
41.218956
530
41.798237
41.098492
531
41.869129
40.866913
532
41.670387
40.917320
533
41.558941
40.732033
534
41.584198
40.888977
535
41.646076
40.699394
536
41.825211
41.314728
537
41.877888
41.209019
538
41.874569
40.779682
539
41.790394
41.036293
540
41.678162
40.679352
541
41.827297
41.073544
542
41.860817
41.381279
543
41.730244
40.923752
544
41.715946
40.956726
545
41.588318
40.659569
546
41.657391
41.284939
547
41.635639
40.834858
548
41.777176
40.736141
549
41.627342
40.674606
550
41.631546
40.893482
551
41.496197
40.710480
552
41.565506
41.486031
553
41.642769
41.107445
554
41.800079
41.078568
555
41.532654
40.545906
556
41.624531
41.028530
557
41.684696
40.934799
558
41.612522
40.689919
559
41.571518
40.780586
560
41.539387
40.757141
561
41.518322
40.659340
562
41.569038
40.936302
563
41.486897
40.640961
564
41.430187
40.707344
565
41.378338
40.546589
566
41.599056
41.007545
567
41.553341
40.736259
568
41.564583
41.023861
569
41.662045
40.648735
570
41.568188
40.632805
571
41.521824
40.756264
572
41.604736
40.992992
573
41.604275
40.914330
574
41.381992
40.543228
575
41.491383
41.032055
576
41.395073
40.678730
577
41.531811
40.694538
578
41.648514
41.222237
579
41.531082
40.667088
580
41.437744
40.539490
581
41.495621
40.520508
582
41.512562
40.566673
583
41.542316
40.656567
584
41.661018
41.384239
585
41.531498
40.707802
586
41.604176
40.570213
587
41.402355
40.556934
588
41.500256
40.995655
589
41.409679
40.516075
590
41.418301
40.655670
591
41.493027
40.445225
592
41.672611
41.460804
593
41.607914
40.939213
594
41.697895
41.253914
595
41.482330
40.809391
596
41.303638
40.486267
597
41.318302
40.648708
598
41.429344
40.570374
599
41.366116
40.649960
600
41.401955
40.708874
601
41.520546
40.689365
602
41.548271
41.571499
603
41.515591
40.744755
604
41.444080
40.479248
605
41.337673
40.466507
606
41.365715
40.729218
607
41.626705
41.067101
608
41.394535
40.419418
609
41.345085
40.508076
610
41.424660
40.940689
611
41.424084
40.622986
612
41.433434
40.528633
613
41.535995
40.617268
614
41.560452
40.488663
615
41.459953
40.465042
616
41.549667
40.667744
617
41.405312
40.801685
618
41.304848
40.441494
619
41.367802
40.485817
620
41.491177
40.628551
621
41.549614
40.734730
622
41.433884
40.462082
623
41.593071
41.021183
624
41.686382
40.952255
625
41.644787
40.702011
626
41.346592
40.492943
627
41.300320
40.398438
628
41.330936
40.805431
629
41.313736
40.456326
630
41.471794
40.817657
631
41.313923
40.473148
632
41.273590
40.498230
633
41.191666
40.406803
634
41.139400
40.737961
635
41.333096
40.849392
636
41.291962
40.756123
637
41.275272
40.473339
638
41.309368
40.837959
639
41.361656
40.640419
640
41.385620
40.486992
641
41.507359
41.046646
642
41.451199
40.665524
643
41.437210
40.470062
644
41.351738
40.841770
645
41.327591
40.784107
646
41.319489
40.499355
647
41.370430
40.459496
648
41.218079
40.390503
649
41.210152
40.357952
650
41.219578
40.524673
651
41.202190
40.282078
652
41.288521
40.614742
653
41.077396
40.189632
654
41.194698
40.652267
655
41.119354
40.231861
656
41.231789
40.615288
657
41.355175
40.359596
658
41.276402
40.235004
659
41.639481
41.093361
660
41.460621
40.685638
661
41.344440
40.374729
662
41.195580
40.403080
663
41.152054
40.261646
664
41.048531
40.121506
665
41.154869
40.312317
666
41.304161
40.715458
667
41.166290
40.214252
668
41.284626
40.363964
669
41.340397
40.522205
670
41.290802
40.421261
671
41.378082
40.558651
672
41.321747
40.596020
673
41.366146
40.747093
674
41.243782
40.486820
675
41.185677
40.374935
676
41.140579
40.521584
677
41.249638
40.657227
678
41.258457
40.267982
679
41.360931
40.932064
680
41.170448
40.265610
681
41.058727
40.357063
682
41.281151
41.072071
683
41.151405
40.766479
684
41.176552
40.171219
685
41.214455
40.509598
686
41.161293
40.298279
687
41.183250
40.214439
688
41.060959
40.520542
689
40.984051
40.378044
690
40.988049
40.232063
691
41.032921
40.665844
692
41.157627
40.345905
693
41.197033
40.318069
694
41.074337
40.127693
695
41.252502
40.383270
696
41.067722
40.366169
697
41.070461
40.564323
698
41.082527
40.214256
699
41.159801
40.227875
700
41.136448
40.270336
701
41.096451
40.226452
702
41.209103
40.380409
703
41.131569
40.243477
704
41.003620
40.051167
705
41.166714
40.572136
706
41.033661
40.303261
707
40.975845
39.971012
708
40.981819
40.063187
709
40.925880
40.112511
710
41.067966
40.683857
711
41.086628
40.367264
712
40.912136
40.082050
713
41.063393
40.407467
714
41.054035
40.265480
715
40.924576
40.178741
716
40.946564
40.182468
717
41.023506
40.353886
718
41.071487
40.420639
719
41.115040
40.309917
720
41.014431
40.474224
721
40.942276
40.043350
722
40.886715
40.085697
723
40.936646
40.060089
724
40.975372
40.077240
725
40.967655
40.656929
726
41.018806
40.343430
727
40.996471
40.275791
728
40.933498
40.224888
729
41.185730
41.394409
730
41.017967
39.964413
731
40.900627
40.621731
732
41.085785
40.420341
733
41.220982
40.448334
734
41.230083
40.103645
735
41.085861
40.142178
736
40.889782
40.201962
737
40.834663
40.403458
738
40.891838
40.693310
739
40.997715
40.297123
740
40.916607
40.238029
741
41.089497
40.268532
742
40.975601
40.186878
743
40.971169
40.369740
744
40.998337
39.976585
745
41.029079
40.195789
746
41.036694
40.374249
747
40.853886
39.979923
748
41.036224
41.102417
749
40.761005
40.015652
750
40.937458
40.373241
751
41.052063
40.220867
752
41.040680
40.737900
753
41.008465
40.819061
754
40.963085
40.209064
755
40.906162
40.123703
756
40.956936
40.376244
757
40.833447
40.136555
758
40.922260
40.071774
759
40.789394
40.114037
760
40.720764
40.011383
761
40.814255
40.178722
762
40.811016
40.074051
763
40.933746
40.486237
764
40.831131
40.016827
765
41.042549
41.079002
766
40.843761
39.904072
767
40.995380
40.416908
768
40.804173
40.005966
769
40.880569
40.051498
770
40.912930
40.050156
771
40.997143
40.335690
772
40.905266
40.051048
773
41.004200
40.618263
774
40.978661
40.312965
775
40.911121
40.455830
776
40.934414
39.908123
777
40.916431
40.018711
778
40.868862
40.035873
779
40.885670
40.418522
780
40.816441
39.980751
781
40.831200
40.423901
782
40.853802
40.103184
783
40.887348
40.341496
784
40.981075
40.406944
785
41.011086
40.304764
786
41.017017
40.588020
787
40.916386
40.085567
788
40.961414
40.277885
789
40.971256
40.065929
790
40.792004
39.989799
791
40.959461
40.285873
792
40.913250
40.565445
793
40.856689
40.040825
794
40.883919
40.298317
795
40.847378
39.931561
796
40.807690
39.921162
797
40.957062
40.206436
798
40.914886
40.175301
799
40.853085
39.775169
800
40.940441
40.161663
801
41.023335
40.231773
802
40.931759
40.045860
803
40.911732
39.979954
804
40.783710
39.930870
805
40.758572
39.909874
806
40.792576
40.040386
807
40.732674
39.860939
808
40.754860
39.982777
809
40.749760
39.906776
810
40.921444
40.970131
811
40.995865
40.865070
812
40.854610
40.027294
813
40.912228
40.486668
814
41.020950
40.134705
815
41.020638
40.232574
816
40.979534
40.663425
817
40.939857
40.470528
818
40.854759
40.057556
819
40.962372
40.234436
820
40.886402
40.519402
821
40.858433
40.075470
822
40.709923
39.929611
823
40.747318
40.040264
824
40.854713
40.244896
825
40.821587
39.987995
826
40.797852
39.919605
827
40.830460
40.380699
828
40.786392
39.852871
829
41.035145
40.107105
830
40.860668
39.845684
831
40.720524
39.894104
832
40.607357
39.679382
833
40.684887
39.846333
834
40.778606
40.394276
835
40.765900
39.939323
836
40.789360
40.130211
837
40.853348
40.483345
838
40.796124
39.814774
839
41.033669
41.100956
840
40.895020
40.389412
841
40.917355
40.565418
842
40.889256
40.017387
843
40.891983
40.241737
844
40.805092
40.372475
845
40.851681
40.248898
846
40.896931
40.892208
847
40.673424
39.977539
848
40.707779
40.055683
849
40.811199
40.334545
850
40.761631
39.748600
851
40.663139
39.922260
852
40.643574
39.794044
853
40.664204
40.014778
854
40.592197
39.708874
855
40.816669
40.455219
856
40.720055
39.844826
857
40.718552
39.940586
858
40.613964
39.866539
859
40.646267
39.843620
860
40.720890
40.224220
861
40.651451
39.945259
862
40.639641
39.675095
863
40.624245
39.782650
864
40.543926
39.642170
865
40.714779
40.279339
866
40.869873
40.452320
867
40.779217
39.903790
868
40.693844
40.072952
869
40.826004
40.404305
870
40.914467
40.210590
871
40.910683
40.033123
872
40.772545
39.731583
873
40.677059
39.944336
874
40.656075
39.843483
875
40.701370
39.970848
876
40.755684
40.430878
877
40.647034
39.719284
878
40.667519
40.173336
879
40.951550
40.459000
880
40.839649
39.846504
881
40.720222
39.831676
882
40.675961
39.890724
883
40.711239
39.744167
884
40.802536
40.458336
885
40.806564
40.309338
886
40.696140
39.750095
887
40.701099
39.775692
888
40.797997
40.449238
889
40.709328
39.882046
890
40.628433
40.026527
891
40.640503
39.730396
892
40.561245
39.975475
893
40.685043
40.187569
894
40.755295
39.989983
895
40.766693
40.635258
896
40.877121
39.987415
897
40.668056
39.940529
898
40.667946
39.888474
899
40.719513
39.822861
900
40.809032
39.923912
901
40.733295
40.116039
902
40.621716
39.949795
903
40.505562
39.725430
904
40.517445
39.866833
905
40.645538
39.874489
906
40.560749
39.623547
907
40.487881
40.043427
908
40.598515
40.013248
909
40.674671
39.987736
910
40.744576
39.776794
911
40.644955
39.774708
912
40.619812
40.406536
913
40.616779
39.846989
914
40.616253
39.874649
915
40.611301
40.175484
916
40.653305
40.090237
917
40.657349
39.597282
918
40.546528
39.802139
919
40.499569
39.872402
920
40.575958
39.874870
921
40.696377
40.103294
922
40.649117
39.918270
923
40.522972
39.600502
924
40.577835
40.112892
925
40.633358
39.994263
926
40.644714
40.010845
927
40.735760
40.165901
928
40.720463
40.255413
929
40.792141
40.219570
930
40.683563
40.170044
931
40.598984
39.904247
932
40.555508
39.931896
933
40.514114
39.786411
934
40.450920
39.605419
935
40.534355
39.626324
936
40.578091
39.945599
937
40.661781
39.832367
938
40.558086
40.150303
939
40.729145
40.869461
940
40.639793
39.808434
941
40.489254
39.636040
942
40.589981
40.369347
943
40.559216
39.810024
944
40.680058
39.892796
945
40.665073
39.821957
946
40.636124
39.763386
947
40.626278
39.673855
948
40.596588
40.002537
949
40.498104
39.673767
950
40.485558
40.115692
951
40.430492
39.663311
952
40.580757
39.736725
953
40.636387
40.064297
954
40.568398
39.910618
955
40.525951
39.909176
956
40.726250
39.986588
957
40.698860
40.060364
958
40.779949
40.155952
959
40.774429
39.980026
960
40.569839
39.726318
961
40.557690
39.821682
962
40.569077
39.892704
963
40.497444
39.667789
964
40.440483
39.623306
965
40.486912
39.751240
966
40.524788
39.752502
967
40.465141
39.677811
968
40.475712
39.677692
969
40.541534
39.843723
970
40.567284
39.868816
971
40.564899
39.756290
972
40.536880
39.690598
973
40.383530
39.671398
974
40.388718
39.674389
975
40.453445
39.920189
976
40.500866
39.818985
977
40.500160
40.493385
978
40.463108
39.753830
979
40.390816
39.838242
980
40.498051
40.046543
981
40.560879
40.093803
982
40.514053
40.029129
983
40.593937
40.488987
984
40.501015
39.599293
985
40.445526
39.805458
986
40.511055
39.800167
987
40.518280
39.735714
988
40.412918
39.525146
989
40.602623
40.099434
990
40.613029
40.253044
991
40.564087
39.914024
992
40.535580
39.601521
993
40.572948
40.043877
994
40.684002
39.950649
995
40.539352
39.803951
996
40.483269
39.676365
997
40.624840
39.573715
998
40.630455
40.416370
999
40.677254
40.898376
1000
40.486931
39.687836
1001
40.527069
39.633186
1002
40.474026
39.606518
1003
40.500767
39.563538
1004
40.485863
40.457455
1005
40.459488
39.804401
1006
40.507988
39.729321
1007
40.538780
40.188988
1008
40.561661
39.704449
1009
40.533539
40.208008
1010
40.476673
39.585136
1011
40.497005
40.391151
1012
40.471378
39.857773
1013
40.729794
40.056480
1014
40.483036
39.558571
1015
40.507442
39.715366
1016
40.643284
40.431927
1017
40.542389
40.362988
1018
40.531849
40.393620
1019
40.623852
40.329178
1020
40.538666
39.800465
1021
40.497059
39.715302
1022
40.617298
39.965569
1023
40.517368
39.815086
1024
40.464912
39.873730
1025
40.424934
39.619011
1026
40.442238
40.082062
1027
40.506176
39.910339
1028
40.538811
39.917976
1029
40.452213
39.594452
1030
40.449898
39.832100
1031
40.538639
39.809990
1032
40.468521
39.660194
1033
40.581989
40.132023
1034
40.480812
39.645756
1035
40.521618
39.545223
1036
40.509872
39.884480
1037
40.571037
39.678276
1038
40.575821
40.458538
1039
40.587395
39.857899
1040
40.403427
39.600536
1041
40.424961
40.102814
1042
40.327824
39.776329
1043
40.405025
39.663013
1044
40.545185
39.715286
1045
40.433685
40.063530
1046
40.493069
39.633972
1047
40.522240
39.692543
1048
40.499191
39.971050
1049
40.565495
40.043434
1050
40.490971
39.503593
1051
40.394646
39.452713
1052
40.505394
39.930416
1053
40.457630
39.584251
1054
40.524063
40.906895
1055
40.381950
39.533649
1056
40.447968
39.767536
1057
40.485237
40.004574
1058
40.591839
40.087379
1059
40.569870
40.261402
1060
40.489723
39.738266
1061
40.394268
39.961235
1062
40.308933
39.570168
1063
40.439533
40.153099
1064
40.386005
40.076324
1065
40.464310
39.864174
1066
40.455135
39.433292
1067
40.483479
39.604038
1068
40.376648
39.542095
1069
40.354000
39.770676
1070
40.407146
39.511738
1071
40.490440
39.923729
1072
40.420464
39.641220
1073
40.410698
39.760014
1074
40.379402
39.490726
1075
40.686226
40.676811
1076
40.502636
39.837307
1077
40.486832
39.803295
1078
40.520935
39.821896
1079
40.576923
39.577606
1080
40.440536
40.169582
1081
40.517105
40.211643
1082
40.534313
39.955830
1083
40.552937
39.955944
1084
40.496479
39.687641
1085
40.425030
39.579224
1086
40.378483
39.725609
1087
40.428391
39.508148
1088
40.394848
39.820133
1089
40.409531
39.744545
1090
40.401943
39.766262
1091
40.380562
39.655956
1092
40.461456
39.853828
1093
40.645985
40.137665
1094
40.498913
40.011494
1095
40.472145
40.146420
1096
40.364918
39.754036
1097
40.545738
40.523785
1098
40.641361
40.056324
1099
40.392792
39.549179
1100
40.432205
39.489304
1101
40.433895
39.814709
1102
40.520691
39.927532
1103
40.701027
40.244080
1104
40.543392
39.894527
1105
40.545578
39.996243
1106
40.487587
39.714054
1107
40.467247
40.363487
1108
40.452919
40.516113
1109
40.510250
39.931614
1110
40.523323
39.992432
1111
40.542763
40.571030
1112
40.423470
40.211361
1113
40.438393
39.843109
1114
40.543892
39.796761
1115
40.478603
39.722286
1116
40.475811
40.108189
1117
40.583199
40.107273
1118
40.468925
40.136986
1119
40.609550
40.372501
1120
40.371319
39.565388
1121
40.480164
40.028801
1122
40.382301
39.773926
1123
40.428699
39.981255
1124
40.381283
39.535843
1125
40.392227
39.840981
1126
40.424850
40.076019
1127
40.432117
39.641319
1128
40.479900
39.794704
1129
40.511364
39.746494
1130
40.358841
39.422947
1131
40.337990
39.645966
1132
40.301529
39.420403
1133
40.338196
39.617676
1134
40.460522
40.035027
1135
40.485474
39.707615
1136
40.485203
39.854336
1137
40.416042
39.803383
1138
40.462479
39.785225
1139
40.402809
39.971245
1140
40.505932
40.322811
1141
40.507717
39.655350
1142
40.411400
39.754681
1143
40.284733
39.588455
1144
40.342941
39.694672
1145
40.372261
39.576389
1146
40.576756
39.666443
1147
40.579685
39.904652
1148
40.359253
39.608665
1149
40.371147
39.413570
1150
40.436649
39.536232
1151
40.463135
39.963833
1152
40.489616
40.031631
1153
40.432259
39.951969
1154
40.511436
39.979687
1155
40.430931
39.630451
1156
40.452488
39.747173
1157
40.562775
40.380337
1158
40.472847
39.765800
1159
40.488167
39.697468
1160
40.494900
40.141106
1161
40.432598
39.739311
1162
40.451591
39.726501
1163
40.405762
39.619675
1164
40.384483
39.770905
1165
40.408836
39.742611
1166
40.442734
39.992306
1167
40.394329
39.610775
1168
40.493603
39.514900
1169
40.331509
39.693722
1170
40.322540
39.744461
1171
40.407345
40.125820
1172
40.417137
40.112587
1173
40.344223
39.536308
1174
40.510098
39.694836
1175
40.384354
39.511013
1176
40.372833
39.836594
1177
40.311241
39.634583
1178
40.382477
39.745964
1179
40.412781
39.847099
1180
40.389812
39.597218
1181
40.576595
39.988041
1182
40.496552
39.812248
1183
40.457848
39.580673
1184
40.393391
39.518860
1185
40.433102
39.617455
1186
40.376633
39.579475
1187
40.583210
40.557339
1188
40.498699
39.614052
1189
40.394936
39.589581
1190
40.434704
39.580017
1191
40.347633
39.601406
1192
40.276176
39.741463
1193
40.377785
40.142254
1194
40.369446
39.658661
1195
40.351711
39.570824
1196
40.386543
39.613274
1197
40.523743
40.351780
1198
40.451385
39.876171
1199
40.337227
39.680267
1200
40.396854
39.841129
In [139]:
# Plot the training and validation losses recorded during fitting.
l.recorder.plot_losses()
In [135]:
# Save the learner's weights under a name encoding the main hyperparameters.
# NOTE(review): the "_mse_" tag is stale — the active loss is the MAPE-style
# `criterion`, not nn.MSELoss. Left unchanged so it still matches the load cell.
# NOTE(review): by execution count this cell (In[135]) ran *before* the long
# fit_one_cycle run (In[137]), so the checkpoint may not hold those weights;
# re-run it after training.
l.save(f"speedup_{optimizer}_batch_norm_{batch_norm}_mse_nlayers_{len(layers_sizes)}")
In [6]:
# Reload the weights written by the save cell (same filename expression).
# The reassignment assumes Learner.load returns the learner — TODO confirm for this fastai version.
l = l.load(f"speedup_{optimizer}_batch_norm_{batch_norm}_mse_nlayers_{len(layers_sizes)}")
In [7]:
def prediction_frame(learner, ds_type):
    """Collect a learner's predictions on one dataset split into a DataFrame.

    Args:
        learner: fastai Learner whose ``get_preds`` returns (preds, targets).
        ds_type: which split to predict on (a ``DatasetType`` member).

    Returns:
        DataFrame with columns ``pred``, ``target``, ``abs_diff``
        (``|pred - target|``) and ``APE`` (``|target - pred| / target * 100``).
        No eps is added in APE, so targets are assumed to be non-zero.
    """
    preds, targets = learner.get_preds(ds_type)
    preds = preds.reshape((-1,)).numpy()
    targets = targets.reshape((-1,)).numpy()
    frame = pd.DataFrame({'pred': preds, 'target': targets})
    frame['abs_diff'] = np.abs(preds - targets)
    frame['APE'] = np.abs(frame.target - frame.pred) / frame.target * 100
    return frame


val_df = prediction_frame(l, fai.basic_data.DatasetType.Valid)
train_df = prediction_frame(l, fai.basic_data.DatasetType.Train)
In [105]:
# Summary statistics of training-set predictions and errors.
# NOTE(review): Out[105] is numbered above the In[7] cell that builds train_df,
# so it comes from an earlier kernel session; re-run to refresh it.
train_df.describe()
Out[105]:
pred
target
abs_diff
APE
count
90000.000000
90000.000000
90000.000000
90000.000000
mean
1.941154
2.031324
0.684319
61.384346
std
1.609805
1.925734
0.753625
107.120293
min
-0.157797
0.028617
0.000011
0.000544
25%
0.742593
0.649040
0.180828
14.958478
50%
1.470519
1.396559
0.429848
32.670835
75%
2.753879
2.621620
0.925115
61.996639
max
10.633808
13.560771
8.974062
2388.007812
In [100]:
# NOTE(review): duplicate of the previous cell. Its Out[100] differs from Out[105]
# (e.g. pred mean 1.956 vs 1.941), so the two outputs reflect different model
# states; keep only one after a clean re-run.
train_df.describe()
Out[100]:
pred
target
abs_diff
APE
count
90000.000000
90000.000000
90000.000000
90000.000000
mean
1.955619
2.031322
0.686646
65.346382
std
1.562531
1.925733
0.743749
118.350204
min
-0.145053
0.028617
0.000002
0.000185
25%
0.790863
0.649040
0.187947
15.021752
50%
1.513035
1.396559
0.444722
32.155708
75%
2.789342
2.621620
0.923823
63.015770
max
10.399451
13.560771
8.745758
2802.083008
In [134]:
# Summary statistics of validation-set predictions and errors.
# NOTE(review): in Out[134] the mean prediction (~1.50) is well below the mean
# target (~2.02), i.e. the model systematically under-predicts in that state.
val_df.describe()
Out[134]:
pred
target
abs_diff
APE
count
10000.000000
10000.000000
10000.000000
10000.000000
mean
1.502686
2.019351
0.763961
42.495895
std
1.421925
1.923597
0.926997
36.887840
min
-0.139572
0.031008
0.000083
0.007530
25%
0.433389
0.645243
0.150516
17.843959
50%
0.992827
1.387672
0.426106
35.945587
75%
2.151397
2.617450
1.029043
60.084542
max
8.528646
13.558847
7.275390
654.606445
In [101]:
# NOTE(review): duplicate of the previous cell. Out[101] and Out[134] differ, so they
# come from different model states; keep only one after a clean re-run.
val_df.describe()
Out[101]:
pred
target
abs_diff
APE
count
10000.000000
10000.000000
10000.000000
10000.000000
mean
1.959154
2.019354
0.676217
65.818283
std
1.570999
1.923599
0.731138
122.390106
min
-0.089540
0.031008
0.000007
0.000882
25%
0.777360
0.645243
0.177962
14.538821
50%
1.504867
1.387672
0.436823
31.890802
75%
2.792107
2.617450
0.914852
64.005240
max
10.399451
13.558847
6.138021
2274.407227
In [ ]:
# Predicted vs. actual speedup on the validation split, with marginal distributions.
sns.jointplot(x='target', y='pred', data=val_df);
Content source: rbaghdadi/COLi
Similar notebooks: