In [4]:
from os import environ
environ['optimizer'] = 'Adam'
environ['num_workers']= '2'
environ['batch_size']= str(2048)
environ['n_epochs']= '1000'
environ['batch_norm']= 'True'
environ['loss_func']='MAPE'
environ['layers'] = '600 350 200 180'
environ['dropouts'] = '0.1 '* 4
environ['log'] = 'False'
environ['weight_decay'] = '0.01'
environ['cuda_device'] ='cuda:4'
environ['dataset'] = 'data/speedup_dataset2.pkl'
%run utils.ipynb
In [5]:
# Split the dataset into train/validation/test loaders and bundle them for fastai.
train_dl, val_dl, test_dl = train_dev_split(dataset, batch_size, num_workers, log=log)
# Pass the test loader by keyword: in fastai v1 the third positional parameter of
# DataBunch is `fix_dl` (fixed-order train loader), so passing it positionally
# would mis-bind it and leave `db.test_dl` as None.
db = fai.basic_data.DataBunch(train_dl, val_dl, test_dl=test_dl, device=device)
In [6]:
# Build the regression model, choose the loss, and wrap both in a fastai Learner.
input_size = train_dl.dataset.X.shape[1]
output_size = train_dl.dataset.Y.shape[1]

if batch_norm:
    model = Model_BN(input_size, output_size, hidden_sizes=layers_sizes, drops=drops)
else:
    # NOTE(review): this branch ignores `layers_sizes` and `drops` — confirm Model's defaults are intended.
    model = Model(input_size, output_size)

# Fail fast on an unsupported loss name instead of a NameError on `criterion` later.
criteria = {'MSE': nn.MSELoss(), 'MAPE': mape_criterion, 'SMAPE': smape_criterion}
if loss_func not in criteria:
    raise ValueError(f"Unsupported loss_func {loss_func!r}; expected one of {sorted(criteria)}")
criterion = criteria[loss_func]

# Only 'SGD' overrides the Learner's default opt_func; any other name would be
# silently ignored while still appearing in the checkpoint filename.
if optimizer not in ('Adam', 'SGD'):
    raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'Adam' or 'SGD'")

l = fai.Learner(db, model, loss_func=criterion, metrics=[mape_criterion, rmse_criterion])
if optimizer == 'SGD':
    l.opt_func = optim.SGD
In [4]:
# Restore weights saved under the configuration-encoded name.
# NOTE(review): execution count In[4] shows this ran out of order (in a later
# session); under Restart & Run All it loads *before* training, so fit_one_cycle
# below would continue from these weights, and it fails if no such file exists.
checkpoint_name = "r_speedup_{}_batch_norm_{}_{}_nlayers_{}_log_{}".format(
    optimizer, batch_norm, loss_func, len(layers_sizes), log)
l = l.load(checkpoint_name)
In [125]:
# Run the learning-rate range test; results are stored on l.recorder (plotted in the next cell).
l.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [126]:
# Plot loss vs. learning rate from the LR-finder run above.
l.recorder.plot()
In [7]:
# Maximum learning rate for the one-cycle schedule (presumably read off the LR-finder plot).
lr = 0.001
In [8]:
# Train with the one-cycle policy peaking at `lr`.
# NOTE(review): hard-coded epoch count differs from environ['n_epochs'] ('1000')
# set in the config cell — confirm which value is intended.
cycle_epochs = 600
l.fit_one_cycle(cycle_epochs, lr)
Total time: 19:09
epoch
train_loss
valid_loss
mape_criterion
rmse_criterion
1
102.462837
89.915024
89.915024
2.250221
2
95.942848
89.159119
89.159119
2.241213
3
93.677299
88.585869
88.585869
2.241689
4
92.049332
88.617554
88.617554
2.237574
5
90.969185
88.596176
88.596176
2.234123
6
91.528877
88.353172
88.353172
2.234489
7
89.688362
87.845184
87.845184
2.229401
8
88.902145
86.188789
86.188789
2.225424
9
87.189957
84.940178
84.940178
2.218786
10
86.617989
84.251053
84.251053
2.217868
11
86.144867
83.626022
83.626022
2.210706
12
84.905479
82.636887
82.636887
2.206949
13
83.870193
81.844536
81.844536
2.205022
14
83.264206
81.042114
81.042114
2.195430
15
82.299522
80.251030
80.251030
2.184253
16
81.930260
79.457634
79.457634
2.179665
17
80.887802
78.319626
78.319626
2.164439
18
80.192841
77.398598
77.398598
2.154759
19
79.316902
76.255775
76.255775
2.138506
20
78.167198
74.610641
74.610641
2.110413
21
77.248169
73.121490
73.121490
2.083524
22
76.619637
71.899475
71.899475
2.071931
23
75.689087
70.554329
70.554329
2.049546
24
75.462486
69.788422
69.788422
2.034508
25
74.128059
69.134758
69.134758
2.027010
26
74.013397
69.124947
69.124947
2.017795
27
73.529907
68.545235
68.545235
2.015644
28
72.503899
67.715485
67.715485
1.993982
29
72.018829
67.323875
67.323875
1.992565
30
71.861832
68.957306
68.957306
2.016355
31
71.359268
66.459358
66.459358
1.967845
32
70.976089
66.587563
66.587563
1.964452
33
70.218262
65.609566
65.609566
1.942873
34
70.110603
66.536560
66.536560
1.964318
35
69.803085
66.303574
66.303574
1.955917
36
69.532188
66.222702
66.222702
1.968754
37
69.123528
65.792366
65.792366
1.954502
38
68.652756
64.899185
64.899185
1.923390
39
68.343796
64.488197
64.488197
1.917168
40
67.861549
64.497505
64.497505
1.918128
41
67.511017
64.443481
64.443481
1.920465
42
67.215637
63.580688
63.580688
1.912533
43
66.639343
64.007141
64.007141
1.894527
44
65.877548
63.246277
63.246277
1.880019
45
65.822037
63.621712
63.621712
1.931469
46
65.259453
61.078979
61.078979
1.871760
47
64.938972
61.761864
61.761864
1.915089
48
64.027245
59.030312
59.030312
1.820233
49
63.683167
59.898449
59.898449
1.863873
50
63.381924
60.149769
60.149769
1.868998
51
61.532448
57.491699
57.491699
1.811738
52
61.413368
58.958481
58.958481
1.869629
53
60.256184
57.546345
57.546345
1.782692
54
61.000992
59.549824
59.549824
1.852500
55
59.750122
56.763351
56.763351
1.815845
56
58.880177
57.024956
57.024956
1.747679
57
57.977936
57.685349
57.685349
1.818881
58
57.351383
55.792374
55.792374
1.740868
59
61.805122
62.690342
62.690342
1.932269
60
62.743099
60.701401
60.701401
1.881479
61
62.275963
60.291924
60.291924
1.869433
62
60.954506
57.321407
57.321407
1.774410
63
57.536522
55.480381
55.480381
1.769087
64
59.621834
60.492695
60.492695
1.861783
65
58.414085
56.202839
56.202839
1.778402
66
58.000450
55.948418
55.948418
1.778747
67
57.790462
61.868870
61.868870
1.865341
68
57.708767
58.825874
58.825874
1.807891
69
54.291729
55.824299
55.824299
1.769154
70
53.209080
58.772564
58.772564
1.764153
71
52.375916
54.450249
54.450249
1.699080
72
51.397118
54.344074
54.344074
1.719169
73
49.561699
53.976311
53.976311
1.633079
74
49.862606
56.884064
56.884064
1.720541
75
49.479145
62.282661
62.282661
1.823001
76
48.908146
56.493301
56.493301
1.732743
77
48.634907
56.730099
56.730099
1.768547
78
48.033245
57.653049
57.653049
1.701744
79
46.540295
50.654930
50.654930
1.589988
80
47.030609
55.185963
55.185963
1.620118
81
46.125957
53.473927
53.473927
1.603792
82
44.691105
48.369156
48.369156
1.531184
83
44.350109
55.723076
55.723076
1.622825
84
45.898518
62.406368
62.406368
1.709704
85
45.058266
48.468037
48.468037
1.504565
86
43.726620
48.691936
48.691936
1.493813
87
44.156380
49.913189
49.913189
1.574697
88
42.925648
49.838276
49.838276
1.538066
89
42.416237
50.906349
50.906349
1.544893
90
41.707027
52.795776
52.795776
1.637186
91
41.388653
55.610207
55.610207
1.586232
92
42.173203
48.229458
48.229458
1.549769
93
40.100197
48.705097
48.705097
1.440520
94
41.953983
54.809658
54.809658
1.632107
95
41.144203
56.031155
56.031155
1.438029
96
42.218178
48.634567
48.634567
1.523352
97
41.702614
53.788464
53.788464
1.548024
98
40.408802
49.433216
49.433216
1.472887
99
39.715084
48.257210
48.257210
1.435997
100
44.082287
52.253365
52.253365
1.596702
101
42.178127
51.293236
51.293236
1.543357
102
40.538998
45.042782
45.042782
1.406151
103
38.412235
42.888336
42.888336
1.261014
104
37.880253
45.432343
45.432343
1.371741
105
39.726032
49.939201
49.939201
1.447204
106
41.930012
48.238358
48.238358
1.397959
107
39.206535
43.298393
43.298393
1.411555
108
38.627396
53.890980
53.890980
1.542061
109
38.115311
49.441555
49.441555
1.538206
110
37.774334
41.115311
41.115311
1.280305
111
37.242725
50.177139
50.177139
1.455027
112
37.662914
49.813534
49.813534
1.535243
113
37.925167
43.111256
43.111256
1.296932
114
37.597553
47.469551
47.469551
1.489701
115
36.807793
45.197067
45.197067
1.384499
116
37.096867
47.770855
47.770855
1.537368
117
35.754456
47.517849
47.517849
1.427904
118
37.579514
49.060925
49.060925
1.559251
119
36.904114
49.449589
49.449589
1.460099
120
36.047928
41.096149
41.096149
1.327718
121
35.932812
48.796745
48.796745
1.249137
122
36.227360
37.428486
37.428486
1.140873
123
36.399853
51.606255
51.606255
1.479480
124
36.774441
50.831142
50.831142
1.331906
125
35.642670
41.459324
41.459324
1.196531
126
36.321976
39.811092
39.811092
1.238446
127
34.772648
31.381090
31.381090
0.985488
128
34.303116
40.470329
40.470329
1.207205
129
34.631145
40.581608
40.581608
1.217004
130
34.646908
41.931065
41.931065
1.269815
131
34.482819
43.965099
43.965099
1.370830
132
33.739220
34.368599
34.368599
1.106400
133
34.495720
40.929226
40.929226
1.324431
134
36.684319
50.006027
50.006027
1.643450
135
34.724823
42.033699
42.033699
1.208778
136
34.700245
38.832420
38.832420
1.186832
137
34.516644
42.992619
42.992619
1.293013
138
33.454983
33.238853
33.238853
1.083079
139
33.991093
35.447414
35.447414
1.103924
140
33.476212
39.205551
39.205551
1.126302
141
33.460987
40.208801
40.208801
1.260162
142
34.475582
39.001701
39.001701
1.181114
143
33.446178
36.023533
36.023533
1.117951
144
34.918251
49.728058
49.728058
1.525955
145
33.617649
37.027851
37.027851
1.177658
146
32.607128
40.295261
40.295261
1.218831
147
34.180916
45.603886
45.603886
1.405544
148
34.175728
43.179070
43.179070
1.211370
149
34.803402
38.488731
38.488731
1.199201
150
33.536026
36.844143
36.844143
1.105938
151
32.863754
41.294613
41.294613
1.243217
152
33.102325
38.399017
38.399017
1.187201
153
33.886322
39.514030
39.514030
1.214723
154
32.800175
34.625404
34.625404
1.080302
155
31.884401
38.613480
38.613480
1.294129
156
33.187286
37.232204
37.232204
1.249735
157
32.381664
37.612865
37.612865
1.220467
158
32.255428
34.255901
34.255901
1.100269
159
31.848295
37.088108
37.088108
1.198890
160
30.968899
31.360840
31.360840
1.011190
161
31.737398
37.021214
37.021214
1.119648
162
35.513985
51.170986
51.170986
1.619334
163
32.419483
32.118145
32.118145
1.042945
164
32.267788
36.559731
36.559731
1.122198
165
32.265480
38.774670
38.774670
1.057342
166
31.750093
31.954960
31.954960
0.993154
167
32.550919
39.066200
39.066200
1.232327
168
31.056742
33.325405
33.325405
1.041935
169
31.377750
33.399265
33.399265
0.989797
170
33.871483
42.348354
42.348354
1.332318
171
31.340851
29.980474
29.980474
0.965288
172
31.399803
35.302532
35.302532
1.103125
173
33.367283
33.063938
33.063938
0.965582
174
31.520449
31.007048
31.007048
0.968705
175
33.009102
41.270267
41.270267
1.149267
176
32.392178
28.921974
28.921974
0.916448
177
31.051737
31.076653
31.076653
0.947544
178
30.386816
28.703388
28.703388
0.921619
179
30.807390
28.721554
28.721554
0.929247
180
30.424442
30.672453
30.672453
0.941073
181
30.349003
32.410038
32.410038
1.036487
182
30.194866
28.087044
28.087044
0.950450
183
30.263041
27.578218
27.578218
0.897694
184
29.797739
29.888056
29.888056
0.948301
185
29.935583
30.774916
30.774916
0.985125
186
29.459436
29.052956
29.052956
0.981487
187
33.849144
35.500896
35.500896
1.212891
188
32.666126
36.224243
36.224243
1.178105
189
30.029305
28.543825
28.543825
0.905285
190
32.766960
38.235004
38.235004
1.261275
191
31.579334
31.061453
31.061453
0.995024
192
31.259674
33.387699
33.387699
1.076000
193
31.128372
30.868631
30.868631
0.964490
194
30.971384
32.275379
32.275379
0.961773
195
30.854416
31.670696
31.670696
0.990046
196
30.184906
30.459307
30.459307
0.984591
197
30.470558
33.204475
33.204475
0.970237
198
30.040091
33.328411
33.328411
0.996078
199
29.553257
31.312305
31.312305
0.948989
200
29.600702
29.081434
29.081434
0.956585
201
29.032640
30.800125
30.800125
0.927248
202
29.429724
31.947538
31.947538
0.993270
203
29.506773
28.804781
28.804781
0.914111
204
29.326370
30.516644
30.516644
0.910598
205
29.402994
36.317833
36.317833
0.990869
206
29.002129
30.929520
30.929520
0.870361
207
30.279745
33.261204
33.261204
0.920278
208
29.205072
27.733316
27.733316
0.907874
209
28.521025
35.237633
35.237633
1.096733
210
28.497412
28.038363
28.038363
0.943656
211
31.133661
32.748589
32.748589
0.920423
212
30.033360
27.724091
27.724091
0.908205
213
28.693943
28.181000
28.181000
0.942253
214
28.000706
30.905306
30.905306
1.051152
215
29.153500
30.739197
30.739197
0.928864
216
28.129633
29.070465
29.070465
0.872321
217
28.315338
27.875219
27.875219
0.889311
218
27.600582
30.403532
30.403532
0.907008
219
27.180147
28.719963
28.719963
0.912207
220
27.602861
30.189648
30.189648
0.916323
221
27.851830
30.176502
30.176502
0.967201
222
28.132318
34.883808
34.883808
0.918503
223
27.547482
28.726351
28.726351
0.875762
224
27.841446
31.372299
31.372299
0.975907
225
27.053761
28.012924
28.012924
0.883935
226
27.428524
28.919622
28.919622
0.886396
227
27.623837
29.478462
29.478462
0.924545
228
26.549494
28.356609
28.356609
0.888520
229
26.166252
29.063196
29.063196
0.838485
230
26.995083
32.883877
32.883877
0.881544
231
26.979536
26.896238
26.896238
0.854347
232
26.347513
33.861176
33.861176
0.857156
233
26.862850
30.716005
30.716005
0.930571
234
26.682184
28.093563
28.093563
0.901258
235
26.211510
27.587170
27.587170
0.870875
236
26.153074
27.288422
27.288422
0.881822
237
26.078346
27.790150
27.790150
0.921486
238
26.404404
29.168247
29.168247
0.919915
239
26.428675
26.849258
26.849258
0.873241
240
26.542526
34.971981
34.971981
0.889436
241
25.921844
28.955263
28.955263
0.849940
242
25.856985
27.943659
27.943659
0.938976
243
25.410679
32.169914
32.169914
0.888684
244
25.466766
28.043381
28.043381
0.904460
245
25.664127
26.928066
26.928066
0.865609
246
25.406288
27.216732
27.216732
0.885839
247
25.136805
30.381466
30.381466
0.961977
248
25.960550
27.826349
27.826349
0.922123
249
25.955774
32.337696
32.337696
0.920436
250
25.592505
26.423412
26.423412
0.889608
251
24.991760
27.597250
27.597250
0.861189
252
25.949627
34.347752
34.347752
0.913615
253
26.009987
28.814590
28.814590
0.915201
254
25.342943
31.485123
31.485123
0.875107
255
24.925245
31.605396
31.605396
0.907865
256
25.129282
35.692726
35.692726
0.936480
257
25.764706
36.718327
36.718327
0.861335
258
25.515604
28.572878
28.572878
0.872104
259
24.672745
29.852816
29.852816
0.884588
260
24.553434
27.237350
27.237350
0.853266
261
24.927618
26.539103
26.539103
0.848360
262
24.688866
27.940100
27.940100
0.869520
263
24.790091
34.699253
34.699253
0.846236
264
24.550695
29.528303
29.528303
0.847367
265
24.272127
26.570253
26.570253
0.889928
266
24.099630
31.263144
31.263144
0.953881
267
24.633497
27.708530
27.708530
0.903770
268
24.316227
33.024181
33.024181
0.872730
269
24.144169
31.144737
31.144737
0.916001
270
24.194763
32.720905
32.720905
0.896790
271
24.339458
27.281113
27.281113
0.836668
272
24.091272
29.804873
29.804873
0.866757
273
24.092991
35.216183
35.216183
0.930799
274
23.619629
33.071480
33.071480
0.855799
275
23.990950
33.026218
33.026218
0.902117
276
24.407785
36.179108
36.179108
0.863402
277
23.897045
29.479143
29.479143
0.845802
278
23.637547
27.101412
27.101412
0.877575
279
23.636490
31.237318
31.237318
0.885830
280
23.524809
28.226509
28.226509
0.888378
281
23.535900
29.908449
29.908449
0.881832
282
23.579260
31.401525
31.401525
0.923868
283
23.483599
28.870726
28.870726
0.868786
284
23.333834
29.210741
29.210741
0.859095
285
23.539259
31.370253
31.370253
0.841539
286
23.517559
29.477880
29.477880
0.854113
287
23.327291
27.406416
27.406416
0.864116
288
23.278925
28.188749
28.188749
0.864451
289
23.204361
26.725651
26.725651
0.879889
290
23.237856
28.023657
28.023657
0.877536
291
22.992851
30.073887
30.073887
0.846566
292
22.998302
27.207613
27.207613
0.845029
293
22.701468
30.449644
30.449644
0.867035
294
23.232189
31.119488
31.119488
0.847104
295
22.928455
27.828276
27.828276
0.826710
296
22.889984
26.132402
26.132402
0.837898
297
22.581001
28.343269
28.343269
0.870220
298
22.613939
30.899712
30.899712
0.882964
299
22.557434
32.113148
32.113148
0.849321
300
22.505795
30.226730
30.226730
0.855375
301
22.548832
35.589844
35.589844
0.873132
302
22.543770
28.692303
28.692303
0.842034
303
22.701599
27.767765
27.767765
0.904211
304
22.400715
27.173031
27.173031
0.860667
305
22.197437
27.604269
27.604269
0.851240
306
22.338448
26.485975
26.485975
0.843565
307
22.221752
31.686926
31.686926
0.855039
308
22.269285
29.615650
29.615650
0.872594
309
22.107435
28.479919
28.479919
0.842593
310
22.405035
32.741665
32.741665
0.900984
311
22.369503
29.492785
29.492785
0.888509
312
22.380257
28.602345
28.602345
0.861227
313
22.052174
31.440632
31.440632
0.846058
314
21.847219
33.097412
33.097412
0.862574
315
22.089762
30.850168
30.850168
0.870820
316
22.143324
27.808676
27.808676
0.886002
317
21.871441
28.382538
28.382538
0.872122
318
21.869408
31.890919
31.890919
0.844921
319
22.021996
33.792290
33.792290
0.810507
320
21.682329
32.085899
32.085899
0.852052
321
21.485136
30.800472
30.800472
0.852270
322
21.554802
27.656782
27.656782
0.838214
323
21.582949
30.987406
30.987406
0.874131
324
21.684252
27.577974
27.577974
0.854570
325
21.649443
28.866381
28.866381
0.874095
326
21.800903
28.920055
28.920055
0.830214
327
21.779699
27.880869
27.880869
0.833062
328
21.485964
33.646626
33.646626
0.842734
329
21.803717
32.444542
32.444542
0.837914
330
21.330885
25.325987
25.325987
0.836160
331
21.149656
27.959869
27.959869
0.837313
332
21.619190
32.521076
32.521076
0.829198
333
21.277458
26.951662
26.951662
0.868845
334
21.019983
28.715425
28.715425
0.822239
335
20.938623
27.484371
27.484371
0.841657
336
21.112318
31.391222
31.391222
0.874959
337
21.314363
31.533684
31.533684
0.846035
338
21.258247
34.287395
34.287395
0.857185
339
21.088907
31.995419
31.995419
0.889313
340
21.060400
25.391491
25.391491
0.850021
341
21.191385
29.911272
29.911272
0.846361
342
21.239538
26.181438
26.181438
0.832716
343
20.908230
27.911318
27.911318
0.866273
344
20.663805
26.280628
26.280628
0.821167
345
20.665287
28.996469
28.996469
0.838689
346
20.797508
28.352690
28.352690
0.863336
347
20.920387
27.500000
27.500000
0.862113
348
20.676994
28.133013
28.133013
0.830970
349
20.972519
29.595919
29.595919
0.836804
350
20.676264
32.270893
32.270893
0.806910
351
20.843134
27.261044
27.261044
0.838302
352
20.774733
33.518471
33.518471
0.863284
353
20.609516
27.182146
27.182146
0.840244
354
20.543415
27.878954
27.878954
0.849199
355
20.385962
28.484194
28.484194
0.863096
356
20.492867
25.872372
25.872372
0.807848
357
20.535784
24.678133
24.678133
0.823150
358
20.517176
29.534426
29.534426
0.849546
359
20.470367
26.525784
26.525784
0.847718
360
20.341015
28.980532
28.980532
0.875093
361
20.425688
29.628828
29.628828
0.829189
362
20.493580
28.659624
28.659624
0.824404
363
20.421024
34.806728
34.806728
0.849442
364
20.164429
25.768372
25.768372
0.828425
365
20.437984
29.775827
29.775827
0.848319
366
20.249952
35.833237
35.833237
0.866141
367
20.339994
28.918982
28.918982
0.830377
368
20.284470
29.664816
29.664816
0.851417
369
20.330862
28.552240
28.552240
0.826999
370
20.048162
29.763744
29.763744
0.847629
371
20.228447
35.931984
35.931984
0.872037
372
19.881866
34.702980
34.702980
0.855211
373
20.269339
34.126816
34.126816
0.821539
374
19.969322
37.500626
37.500626
0.855052
375
19.955568
35.221851
35.221851
0.852405
376
20.342201
40.740067
40.740067
0.855306
377
19.976448
27.932146
27.932146
0.847538
378
20.265495
29.188465
29.188465
0.814933
379
19.943855
30.165432
30.165432
0.831588
380
19.907921
28.370747
28.370747
0.811003
381
19.806030
28.747881
28.747881
0.827588
382
19.566782
29.042213
29.042213
0.815333
383
19.962469
27.231894
27.231894
0.812813
384
19.918219
31.876888
31.876888
0.865599
385
19.670238
30.495247
30.495247
0.826349
386
20.088600
28.658175
28.658175
0.835207
387
19.868584
31.086962
31.086962
0.838345
388
19.701395
32.066730
32.066730
0.835094
389
19.621452
28.678385
28.678385
0.808998
390
19.700653
24.997456
24.997456
0.803856
391
19.574583
29.561003
29.561003
0.821678
392
19.870275
28.705059
28.705059
0.811017
393
19.639271
28.516575
28.516575
0.819947
394
19.638912
30.155981
30.155981
0.833960
395
19.592838
29.500837
29.500837
0.841343
396
19.401573
28.972284
28.972284
0.829232
397
19.476780
25.117693
25.117693
0.800549
398
19.415213
27.676556
27.676556
0.817315
399
19.353481
29.587238
29.587238
0.835576
400
19.329996
30.589844
30.589844
0.814549
401
19.388426
28.798521
28.798521
0.796778
402
19.391117
27.288851
27.288851
0.820750
403
19.465256
31.393837
31.393837
0.836368
404
19.367193
32.006042
32.006042
0.836113
405
19.529978
28.048588
28.048588
0.818935
406
19.145683
29.942076
29.942076
0.831204
407
19.258900
29.778469
29.778469
0.806376
408
19.101170
28.129093
28.129093
0.901656
409
19.261408
32.255196
32.255196
0.827982
410
19.313683
28.843971
28.843971
0.821732
411
19.165995
28.464394
28.464394
0.829256
412
19.210093
35.132942
35.132942
0.832075
413
19.264273
30.158085
30.158085
0.823937
414
19.138166
30.775288
30.775288
0.842165
415
18.951040
30.801260
30.801260
0.837962
416
19.047180
33.842201
33.842201
0.877818
417
18.937719
31.111700
31.111700
0.839707
418
18.912231
26.090418
26.090418
0.819153
419
18.981651
25.895103
25.895103
0.810066
420
19.009682
33.740959
33.740959
0.825031
421
19.068413
26.207594
26.207594
0.800084
422
19.138927
27.260050
27.260050
0.826087
423
19.042442
38.317123
38.317123
0.855723
424
18.792263
30.573797
30.573797
0.821726
425
18.854481
30.468676
30.468676
0.832594
426
18.829100
28.783182
28.783182
0.828581
427
18.863146
28.578630
28.578630
0.801191
428
18.836004
27.001785
27.001785
0.794121
429
18.850929
32.737068
32.737068
0.842025
430
18.769400
35.942917
35.942917
0.829415
431
18.702246
26.559973
26.559973
0.800133
432
18.844038
28.286423
28.286423
0.820270
433
18.560753
28.310850
28.310850
0.821260
434
18.794085
31.361650
31.361650
0.812626
435
18.681713
31.365549
31.365549
0.848335
436
18.713829
33.909851
33.909851
0.825864
437
18.699177
27.860062
27.860062
0.840470
438
18.622734
28.208563
28.208563
0.805614
439
18.489710
26.838825
26.838825
0.810255
440
18.533819
30.129074
30.129074
0.844464
441
18.456362
29.615694
29.615694
0.817194
442
18.592129
30.206934
30.206934
0.817924
443
18.747429
29.960958
29.960958
0.821688
444
18.497389
32.355579
32.355579
0.825524
445
18.496330
33.963696
33.963696
0.839722
446
18.526970
27.462112
27.462112
0.807560
447
18.463797
30.403719
30.403719
0.832315
448
18.292990
29.474281
29.474281
0.806896
449
18.556654
27.655251
27.655251
0.793418
450
18.317999
31.544628
31.544628
0.822906
451
18.420084
29.626957
29.626957
0.811272
452
18.295961
27.606194
27.606194
0.810014
453
18.390396
28.956913
28.956913
0.814843
454
18.216801
28.018705
28.018705
0.801500
455
18.274914
30.064688
30.064688
0.815990
456
18.305866
31.102724
31.102724
0.828104
457
18.174770
32.018856
32.018856
0.821333
458
18.388660
29.696213
29.696213
0.824469
459
18.236008
27.600956
27.600956
0.797204
460
18.253284
26.765680
26.765680
0.793056
461
18.275747
30.504982
30.504982
0.797631
462
18.105152
29.080460
29.080460
0.807240
463
18.224957
29.199797
29.199797
0.810738
464
17.977352
30.023697
30.023697
0.806384
465
18.100113
28.936775
28.936775
0.805018
466
18.182266
35.798580
35.798580
0.841455
467
18.075848
29.712828
29.712828
0.808752
468
18.069944
30.718275
30.718275
0.841107
469
18.028542
31.286587
31.286587
0.818650
470
18.039639
27.987673
27.987673
0.790864
471
18.049980
26.684250
26.684250
0.809462
472
18.041315
29.130526
29.130526
0.798955
473
18.041491
34.469986
34.469986
0.832688
474
17.873390
25.967432
25.967432
0.804805
475
18.024952
26.198193
26.198193
0.795796
476
18.130011
26.925938
26.925938
0.808613
477
17.924801
29.644575
29.644575
0.809224
478
18.003313
30.549294
30.549294
0.812313
479
18.007442
34.245888
34.245888
0.839452
480
17.958710
27.796597
27.796597
0.816399
481
17.787209
32.129025
32.129025
0.826926
482
17.836800
29.350918
29.350918
0.806177
483
17.838070
27.451387
27.451387
0.811297
484
17.926260
28.854450
28.854450
0.794694
485
17.941809
29.902784
29.902784
0.810505
486
17.807796
31.584600
31.584600
0.828748
487
18.123257
25.971964
25.971964
0.802340
488
17.923618
30.926918
30.926918
0.807232
489
18.017157
34.957954
34.957954
0.828751
490
17.863190
35.238113
35.238113
0.826141
491
17.845936
28.953009
28.953009
0.810183
492
17.760710
30.270910
30.270910
0.810206
493
17.636686
29.655172
29.655172
0.792839
494
17.646048
27.377762
27.377762
0.795583
495
17.718960
27.934299
27.934299
0.801404
496
17.782598
27.003584
27.003584
0.789616
497
17.744152
29.458599
29.458599
0.805626
498
17.659798
27.889635
27.889635
0.803372
499
17.586069
27.337358
27.337358
0.804841
500
17.465101
27.234043
27.234043
0.796919
501
17.617983
32.175045
32.175045
0.828921
502
17.517565
28.652172
28.652172
0.801894
503
17.607765
28.732475
28.732475
0.797819
504
17.634298
29.661631
29.661631
0.811863
505
17.451544
28.194799
28.194799
0.812533
506
17.535856
27.850016
27.850016
0.802811
507
17.653004
28.458843
28.458843
0.815564
508
17.485518
27.116488
27.116488
0.789522
509
17.657616
31.270426
31.270426
0.816570
510
17.550051
28.537813
28.537813
0.794867
511
17.456673
26.983768
26.983768
0.808765
512
17.461521
32.131729
32.131729
0.811588
513
17.400482
32.266327
32.266327
0.810670
514
17.409887
27.117775
27.117775
0.791690
515
17.444668
28.755241
28.755241
0.802858
516
17.439150
27.350828
27.350828
0.791582
517
17.425848
27.947987
27.947987
0.792039
518
17.374071
29.243031
29.243031
0.813625
519
17.407967
32.309826
32.309826
0.801910
520
17.316252
29.545788
29.545788
0.817326
521
17.313868
28.906019
28.906019
0.810315
522
17.410135
30.232071
30.232071
0.817198
523
17.382107
30.760994
30.760994
0.812093
524
17.381172
27.351587
27.351587
0.790295
525
17.568604
29.831865
29.831865
0.783787
526
17.461163
34.612133
34.612133
0.832215
527
17.314554
26.954687
26.954687
0.793182
528
17.483107
36.418304
36.418304
0.840229
529
17.310074
27.303513
27.303513
0.804458
530
17.327452
32.026627
32.026627
0.813015
531
17.321669
25.205114
25.205114
0.796766
532
17.192396
26.289732
26.289732
0.804069
533
17.361225
30.578878
30.578878
0.805598
534
17.260796
27.188999
27.188999
0.793238
535
17.303936
30.503344
30.503344
0.798804
536
17.307570
28.446278
28.446278
0.808474
537
17.246763
28.973631
28.973631
0.810209
538
17.247349
28.571320
28.571320
0.807079
539
17.199726
29.678768
29.678768
0.805977
540
17.208178
27.343475
27.343475
0.797696
541
17.080887
26.155279
26.155279
0.792784
542
17.105354
28.845409
28.845409
0.812772
543
17.183191
26.199879
26.199879
0.785627
544
17.099640
28.239481
28.239481
0.806001
545
17.122894
27.121138
27.121138
0.783365
546
17.191368
28.866413
28.866413
0.799720
547
17.174528
28.363388
28.363388
0.793286
548
17.141420
26.097416
26.097416
0.797240
549
17.219872
32.718220
32.718220
0.814210
550
17.072456
28.275219
28.275219
0.803153
551
17.203091
26.738821
26.738821
0.793523
552
17.089849
30.534100
30.534100
0.799119
553
17.178310
30.255531
30.255531
0.804183
554
17.007668
25.847328
25.847328
0.792709
555
17.178520
32.843349
32.843349
0.820920
556
16.983553
27.881701
27.881701
0.799397
557
17.079176
29.980043
29.980043
0.815065
558
17.071980
33.022240
33.022240
0.807456
559
17.049711
26.611809
26.611809
0.788952
560
17.131897
28.906601
28.906601
0.797059
561
17.051199
35.488522
35.488522
0.820983
562
17.138382
32.740532
32.740532
0.809891
563
17.273539
35.268623
35.268623
0.817666
564
17.163408
29.226507
29.226507
0.799978
565
17.175001
34.330811
34.330811
0.820431
566
17.057478
29.133484
29.133484
0.797177
567
16.912235
25.953484
25.953484
0.785920
568
16.994179
29.590284
29.590284
0.791411
569
17.040358
30.785744
30.785744
0.809470
570
17.007362
31.357382
31.357382
0.798695
571
17.104525
26.843304
26.843304
0.799187
572
16.939304
27.436810
27.436810
0.796603
573
17.098824
34.378216
34.378216
0.817264
574
16.954121
31.910265
31.910265
0.802028
575
16.983212
28.555937
28.555937
0.800918
576
17.155170
30.026846
30.026846
0.811817
577
17.242094
36.562836
36.562836
0.814757
578
17.060047
28.909149
28.909149
0.802118
579
16.938492
27.774906
27.774906
0.789041
580
16.952663
24.950550
24.950550
0.784761
581
16.997622
29.644119
29.644119
0.792056
582
17.040215
29.872421
29.872421
0.802756
583
16.964943
26.599747
26.599747
0.785250
584
17.105183
35.066528
35.066528
0.816200
585
17.029114
29.482920
29.482920
0.794696
586
16.967745
29.923212
29.923212
0.804836
587
17.068871
36.625725
36.625725
0.813744
588
17.066595
29.443745
29.443745
0.804752
589
17.046106
31.077038
31.077038
0.796184
590
17.081030
36.422691
36.422691
0.825810
591
16.980398
31.546888
31.546888
0.809645
592
16.957745
27.725594
27.725594
0.801067
593
16.934332
28.236074
28.236074
0.790647
594
17.016115
28.282209
28.282209
0.794521
595
17.015324
27.978348
27.978348
0.796241
596
17.026548
28.905725
28.905725
0.794053
597
17.034653
28.156298
28.156298
0.798595
598
17.076344
30.742912
30.742912
0.808406
599
16.990126
28.792730
28.792730
0.801815
600
16.951323
26.981087
26.981087
0.788165
In [129]:
# Plot training and validation loss curves recorded during fit_one_cycle.
l.recorder.plot_losses()
In [6]:
# Persist the trained weights under a name that encodes the run configuration.
checkpoint_name = "r_speedup_{}_batch_norm_{}_{}_nlayers_{}_log_{}".format(
    optimizer, batch_norm, loss_func, len(layers_sizes), log)
l.save(checkpoint_name)
In [7]:
# List saved checkpoints (.pth files) in the models/ directory.
!ls models
old_models
old_repr
r_speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False2.pth
speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MSE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MSE_nlayers_5_log_True.pth
tmp.pth
In [111]:
# Per-sample results tables (prediction, target, abs_diff, APE, schedule flags)
# for the validation loader first, then the training loader.
val_df, train_df = (get_results_df(loader, l.model) for loader in (val_dl, train_dl))
In [116]:
# Point the analysis frame `df` at the training-set results for the next summary.
df = train_df
In [117]:
# Summary statistics of predictions, targets and errors over every row of `df`.
df.loc[:, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[117]:
prediction
target
abs_diff
APE
count
185651.000000
185651.000000
1.856510e+05
185651.000000
mean
1.159134
1.184810
1.515271e-01
16.350126
std
1.564008
1.613917
4.039645e-01
49.778847
min
0.010645
0.008491
2.384186e-07
0.000060
25%
0.216299
0.211613
1.029227e-02
3.166194
50%
0.569257
0.565550
3.776443e-02
7.030761
75%
1.389433
1.394052
1.252649e-01
13.793568
max
11.054943
16.089287
1.534635e+01
3100.805420
In [118]:
# Point the analysis frame `df` at the validation-set results; the cells below read `df`.
df = val_df
In [119]:
# Summary statistics of predictions, targets and errors over every row of `df`.
df.loc[:, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[119]:
prediction
target
abs_diff
APE
count
10000.000000
10000.000000
10000.000000
10000.000000
mean
1.717750
1.601825
0.545405
50.500225
std
1.808425
1.653595
0.805434
75.607666
min
0.012248
0.010685
0.000005
0.001589
25%
0.372329
0.301894
0.054618
8.127786
50%
1.037479
0.970127
0.214614
23.571514
75%
2.149681
2.176448
0.696176
61.788719
max
8.437875
7.522550
5.974429
654.431885
In [120]:
# Rows with interchange, unroll and tile flags all 0 (presumably untransformed schedules).
df.loc[df.interchange.eq(0) & df.unroll.eq(0) & df.tile.eq(0),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[120]:
prediction
target
abs_diff
APE
count
25.000000
25.0
25.000000
25.000000
mean
2.385324
1.0
1.438993
143.899323
std
1.898048
0.0
1.855988
185.598831
min
0.731159
1.0
0.021526
2.152550
25%
0.973027
1.0
0.069987
6.998700
50%
1.585838
1.0
0.585838
58.583771
75%
3.934247
1.0
2.934247
293.424683
max
6.791101
1.0
5.791101
579.110046
In [48]:
# Rows with only the tile flag set.
# NOTE(review): the saved output (In[48]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[df.interchange.eq(0) & df.unroll.eq(0) & df.tile.eq(1),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[48]:
prediction
target
abs_diff
APE
count
10908.000000
10908.000000
10908.000000
10908.000000
mean
1.300379
1.544865
0.348264
17.498978
std
1.134063
1.561985
0.874742
19.460262
min
0.040741
0.080195
0.000011
0.000417
25%
0.521132
0.561271
0.037481
4.959255
50%
0.952309
1.011548
0.103626
11.225842
75%
1.701679
1.967590
0.268848
22.333059
max
7.149683
16.089287
15.834622
476.782654
In [49]:
# Rows with only the unroll flag set.
# NOTE(review): the saved output (In[49]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[df.interchange.eq(0) & df.unroll.eq(1) & df.tile.eq(0),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[49]:
prediction
target
abs_diff
APE
count
5388.000000
5388.000000
5388.000000
5388.000000
mean
4.646883
5.080168
0.709873
17.410728
std
2.274305
2.592474
0.847952
34.551174
min
0.041153
0.058317
0.000006
0.000110
25%
2.416280
2.779300
0.175089
4.599086
50%
5.243072
5.502153
0.416792
10.081613
75%
6.300526
6.526127
0.916621
17.963830
max
8.230931
13.560771
6.250570
510.743958
In [50]:
# Rows with only the interchange flag set.
# NOTE(review): the saved output (In[50]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[df.interchange.eq(1) & df.unroll.eq(0) & df.tile.eq(0),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[50]:
prediction
target
abs_diff
APE
count
1758.000000
1758.000000
1758.000000
1758.000000
mean
1.041926
1.452486
0.451720
23.524038
std
1.040951
1.578704
0.813260
25.382376
min
0.041223
0.018637
0.000014
0.005628
25%
0.306825
0.347235
0.030223
6.690400
50%
0.681719
0.822524
0.115805
15.353576
75%
1.462868
2.103264
0.442162
31.986423
max
6.201118
9.472253
5.777220
405.513367
In [51]:
# Rows with unroll and tile set, interchange not set.
# NOTE(review): the saved output (In[51]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[df.interchange.eq(0) & df.unroll.eq(1) & df.tile.eq(1),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[51]:
prediction
target
abs_diff
APE
count
21395.000000
21395.000000
21395.000000
21395.000000
mean
1.103823
1.345353
0.325749
19.157253
std
1.122763
1.548089
0.875139
22.390285
min
0.037974
0.029894
0.000007
0.000545
25%
0.387534
0.399394
0.031919
5.276947
50%
0.853066
0.889977
0.086671
11.755613
75%
1.368872
1.568050
0.212368
23.404687
max
8.301966
13.433655
12.628229
364.638855
In [52]:
# Rows with interchange and unroll set, tile not set.
# NOTE(review): the saved output (In[52]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[df.interchange.eq(1) & df.unroll.eq(1) & df.tile.eq(0),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[52]:
prediction
target
abs_diff
APE
count
26827.000000
26827.000000
26827.000000
26827.000000
mean
2.278949
2.478805
0.312914
13.936263
std
1.690642
1.895653
0.552439
35.758522
min
0.041220
0.018075
0.000003
0.000124
25%
0.873289
0.883696
0.049614
3.965816
50%
1.820674
1.956570
0.140420
8.129169
75%
3.415022
3.795769
0.346526
14.133539
max
7.781519
13.558847
6.984952
659.840881
In [53]:
# Rows with interchange and tile set, unroll not set.
# NOTE(review): the saved output (In[53]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[df.interchange.eq(1) & df.unroll.eq(0) & df.tile.eq(1),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[53]:
prediction
target
abs_diff
APE
count
38476.000000
38476.000000
38476.000000
38476.000000
mean
0.599790
0.718080
0.161907
16.824739
std
0.726942
0.959410
0.441344
19.137383
min
0.011075
0.008774
0.000001
0.002214
25%
0.159066
0.163858
0.010282
4.783708
50%
0.350566
0.374172
0.037043
10.798526
75%
0.797795
0.897981
0.118515
21.021241
max
6.714707
9.643795
7.808821
224.782593
In [54]:
# Rows with all three flags (interchange, unroll, tile) set.
# NOTE(review): the saved output (In[54]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[df.interchange.eq(1) & df.unroll.eq(1) & df.tile.eq(1),
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[54]:
prediction
target
abs_diff
APE
count
80352.000000
80352.000000
8.035200e+04
80352.000000
mean
0.502504
0.618813
1.499550e-01
16.857380
std
0.697265
0.935955
4.647943e-01
20.027716
min
0.011103
0.008491
2.942979e-07
0.000142
25%
0.118478
0.126692
8.149397e-03
4.663052
50%
0.289729
0.305125
2.876675e-02
10.372399
75%
0.639897
0.745241
9.044305e-02
20.579370
max
7.473004
12.004527
8.477986e+00
453.119904
In [56]:
# Rows where the interchange, tile and unroll flags do not sum to zero
# (i.e. at least one transformation flag set, assuming non-negative flags).
# NOTE(review): the saved output (In[56]) predates `df = val_df` (In[118]); re-run to refresh.
df.loc[(df.interchange + df.tile + df.unroll) != 0,
       ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[56]:
prediction
target
abs_diff
APE
count
185104.000000
185104.000000
1.851040e+05
185104.000000
mean
1.022474
1.185355
2.272181e-01
16.810253
std
1.374168
1.616282
5.942063e-01
23.641901
min
0.011075
0.008491
2.942979e-07
0.000110
25%
0.197577
0.210773
1.411321e-02
4.647739
50%
0.502107
0.562228
5.195722e-02
10.257700
75%
1.149322
1.398968
1.731274e-01
19.895230
max
8.301966
16.089287
1.583462e+01
659.840881
In [76]:
# Joint plots of prediction vs. target for every distinct (interchange, unroll, tile)
# flag combination, then for rows with any flag set, then for the whole frame.
# Each title names its subset so the figures can be told apart; the original cell
# plotted (0, 1, 1) twice and gave every figure the same title.
# NOTE(review): the "Validation dataset" label assumes `df` is val_df, as assigned
# earlier in notebook order — confirm after Restart & Run All.
flag_combos = [  # (interchange, unroll, tile)
    (0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0),
    (0, 1, 1), (1, 1, 0), (1, 0, 1), (1, 1, 1),
]
for flag_i, flag_u, flag_t in flag_combos:
    df1 = df[(df.interchange == flag_i) & (df.unroll == flag_u) & (df.tile == flag_t)]
    joint_plot(df1, f"Validation dataset, {loss_func} loss, "
                    f"interchange={flag_i} unroll={flag_u} tile={flag_t}")

df1 = df[(df.interchange + df.tile + df.unroll != 0)]
joint_plot(df1, f"Validation dataset, {loss_func} loss, any transformation")

df2 = df
joint_plot(df2, f"Validation dataset, {loss_func} loss, all rows")
In [ ]:
Content source: rbaghdadi/COLi
Similar notebooks: