In [7]:
from os import environ

# Hyper-parameter configuration for this run, handed to utils.ipynb through
# environment variables (every value must be a string; utils.ipynb parses
# them back into typed Python values).
hyperparams = {
    'optimizer': 'Adam',
    'num_workers': '2',
    'batch_size': str(2048),
    'n_epochs': '800',
    'batch_norm': 'True',
    'loss_func': 'SMAPE',
    # Hidden layer widths and one dropout rate per layer (space-separated).
    'layers': '800 700 600 350 200 180',
    'dropouts': '0.1 ' * 6,  # trailing space preserved; parser splits on whitespace
    'lr': '1e-03',
    'log': 'False',
    'weight_decay': '0.011',
    'cuda_device': 'cuda:5',
    'dataset': 'data/speedup_dataset2.pkl',
}
environ.update(hyperparams)
%run utils.ipynb
In [8]:
# Split the pickled dataset into train/validation/test DataLoaders
# (train_dev_split is defined in utils.ipynb, run in the cell above),
# then wrap them in a fastai DataBunch pinned to the configured device.
loaders = train_dev_split(dataset, batch_size, num_workers, log=log)
train_dl, val_dl, test_dl = loaders
db = fai.basic_data.DataBunch(train_dl, val_dl, test_dl, device=device)
function329_schedule_13
0
{'computations': {'computations_array': [{'comp_id': 1,
'lhs_data_type': 'p_int32',
'loop_iterators_ids': [2, 3],
'operations_histogram': [[5, 3, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0,
0,
0,
0]],
'rhs_accesses': {'accesses': [{'access': [[1,
0,
0],
[0,
1,
1]],
'comp_id': 0},
{'access': [[1,
0,
0],
[0,
1,
-1]],
'comp_id': 0},
{'access': [[1,
0,
1],
[0,
1,
0]],
'comp_id': 0},
{'access': [[1,
0,
1],
[0,
1,
1]],
'comp_id': 0},
{'access': [[1,
0,
1],
[0,
1,
-1]],
'comp_id': 0},
{'access': [[1,
0,
-1],
[0,
1,
0]],
'comp_id': 0},
{'access': [[1,
0,
-1],
[0,
1,
1]],
'comp_id': 0},
{'access': [[1,
0,
-1],
[0,
1,
-1]],
'comp_id': 0},
{'access': [[1,
0,
0],
[0,
1,
0]],
'comp_id': 0}],
'n': 9}}],
'n': 1},
'inputs': {'inputs_array': [{'data_type': 'p_int32',
'input_id': 0,
'loop_iterators_ids': [0, 1]}],
'n': 1},
'iterators': {'iterators_array': [{'it_id': 2,
'lower_bound': 1,
'upper_bound': 1048575},
{'it_id': 3,
'lower_bound': 1,
'upper_bound': 63},
{'it_id': 0,
'lower_bound': 0,
'upper_bound': 1048576},
{'it_id': 1,
'lower_bound': 0,
'upper_bound': 64}],
'n': 4},
'loops': {'loops_array': [{'assignments': {'assignments_array': [], 'n': 0},
'loop_id': 0,
'loop_it': 2,
'parent': -1,
'position': 0},
{'assignments': {'assignments_array': [{'id': 1,
'position': 0}],
'n': 1},
'loop_id': 1,
'loop_it': 3,
'parent': 0,
'position': 0}],
'n': 2},
'seed': 329,
'type': 2}
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/data/scratch/henni-mohammed/speedup_model/src/data/loop_ast.py in tile(self, loop_id, factor)
260 try:
--> 261 while loop.iterator.id != loop_id:
262 loop = loop.children[0]
AttributeError: 'Computation' object has no attribute 'iterator'
During handling of the above exception, another exception occurred:
NameError Traceback (most recent call last)
<ipython-input-8-ccbb4c277821> in <module>
----> 1 train_dl, val_dl, test_dl = train_dev_split(dataset, batch_size, num_workers, log=log)
2
3 db = fai.basic_data.DataBunch(train_dl, val_dl, test_dl, device=device)
<ipython-input-7-fa1393f8fd16> in train_dev_split(dataset, batch_size, num_workers, log, seed)
108
109 test_size = validation_size = 10000
--> 110 ds = DatasetFromPkl(dataset, maxsize=None, log=log)
111
112 indices = range(len(ds))
/data/scratch/henni-mohammed/speedup_model/src/data/dataset.py in __init__(self, filename, normalized, log, maxsize)
102 program = self.programs[self.program_indexes[i]]
103
--> 104 self.X.append(program.add_schedule(self.schedules[i]).__array__())
105
106
/data/scratch/henni-mohammed/speedup_model/src/data/loop_ast.py in add_schedule(self, schedule)
273 def add_schedule(self, schedule):
274
--> 275 return Loop_AST(self.name, self.dict_repr, schedule)
276
277 def dtype_to_int(self, dtype):
/data/scratch/henni-mohammed/speedup_model/src/data/loop_ast.py in __init__(self, name, dict_repr, schedule)
218
219 if self.schedule:
--> 220 self.apply_schedule()
221
222
/data/scratch/henni-mohammed/speedup_model/src/data/loop_ast.py in apply_schedule(self)
232 if type_ == 'tiling' and binary_schedule[1] == 1:
233 for loop_id, factor in zip(params, factors):
--> 234 self.tile(loop_id, factor)
235
236 elif type_ == 'interchange' and binary_schedule[0] == 1:
/data/scratch/henni-mohammed/speedup_model/src/data/loop_ast.py in tile(self, loop_id, factor)
269 from pprint import pprint
270 pprint(self.dict_repr)
--> 271 exit(1)
272
273 def add_schedule(self, schedule):
NameError: name 'exit' is not defined
In [3]:
# Build the model and fastai Learner from the loaded DataLoaders.
# Input/output widths come straight from the dataset tensors.
input_size = train_dl.dataset.X.shape[1]
output_size = train_dl.dataset.Y.shape[1]

# Architecture choice: batch-normalised MLP with the configured hidden
# widths/dropouts, or the plain baseline model.
if batch_norm:
    model = Model_BN(input_size, output_size, hidden_sizes=layers_sizes, drops=drops)
else:
    model = Model(input_size, output_size)

# Select the training criterion; fail fast on an unknown loss name instead
# of letting the undefined `criterion` surface later as a NameError.
if loss_func == 'MSE':
    criterion = nn.MSELoss()
elif loss_func == 'MAPE':
    criterion = mape_criterion
elif loss_func == 'SMAPE':
    criterion = smape_criterion
else:
    raise ValueError(f"unsupported loss_func: {loss_func!r}")

l = fai.Learner(db, model, loss_func=criterion, metrics=[mape_criterion, rmse_criterion])
# fastai defaults to Adam; only override the optimizer factory for SGD.
if optimizer == 'SGD':
    l.opt_func = optim.SGD
In [4]:
# Resume from a previously saved checkpoint; the file name encodes the run's
# hyper-parameters (optimizer, batch-norm flag, loss, network depth, log flag)
# so different configurations don't overwrite each other's weights.
l = l.load(f"r_speedup_{optimizer}_batch_norm_{batch_norm}_{loss_func}_nlayers_{len(layers_sizes)}_log_{log}")
In [4]:
# Run fastai's LR range test (short training sweep over increasing learning
# rates) and plot loss vs. LR to pick a max LR for fit_one_cycle.
l.lr_find()
l.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [4]:
# Train with the 1-cycle policy; epoch count and max learning rate are read
# back from the environment variables set in the configuration cell.
l.fit_one_cycle(int(environ['n_epochs']), float(environ['lr']))
Total time: 38:17
epoch
train_loss
valid_loss
mape_criterion
rmse_criterion
1
104.661942
102.300491
287.807861
3.067903
2
95.419327
96.622795
343.448120
3.159903
3
93.724846
95.814415
358.764221
3.266283
4
92.683197
94.970634
364.280121
4.167940
5
79.877151
75.469284
394.043427
4.615024
6
72.907417
71.566765
378.542206
4.514072
7
69.528313
69.789955
373.386292
4.091285
8
67.515762
69.398399
227.227951
3.597232
9
66.386208
67.851303
332.230042
3.537573
10
65.443367
67.711411
187.448212
2.668400
11
64.757523
67.260681
169.101669
2.806999
12
64.312454
65.746140
181.799454
3.179370
13
63.408176
64.634590
179.363205
3.313148
14
63.041748
64.982162
178.473602
3.392530
15
62.523315
64.177925
172.089874
2.984686
16
62.235004
64.324585
168.823624
3.476168
17
61.862000
64.127190
195.083923
4.007714
18
61.721405
63.234776
150.461853
2.163457
19
61.201363
63.207737
153.427155
2.036991
20
60.979752
62.761513
158.769318
2.374927
21
60.773315
61.511395
152.390259
2.316116
22
60.207146
60.094212
143.462357
1.790007
23
59.448761
60.113243
154.291367
2.070540
24
59.000278
59.351776
147.305847
1.757399
25
58.414322
57.866798
140.695786
1.498913
26
57.787773
57.987343
145.372757
1.494222
27
57.437630
59.036324
149.114487
1.439605
28
57.133759
59.246983
149.314529
1.448927
29
56.873142
56.965801
153.228333
1.455193
30
56.486229
57.941837
156.950516
1.327109
31
56.154205
58.173969
156.745132
1.365756
32
55.974373
57.153961
145.573135
1.335501
33
55.483322
56.434643
152.174332
1.304802
34
55.551369
55.601543
149.132156
1.300819
35
55.141087
54.656437
143.004578
1.306897
36
54.809830
53.274967
139.050888
1.292370
37
54.907505
54.076538
142.654037
1.332063
38
54.504459
55.915386
143.768600
1.353823
39
54.128902
53.436577
135.896545
1.325468
40
54.017067
53.831738
150.185028
1.290974
41
53.862091
55.840027
147.336594
1.314821
42
53.527283
54.076736
141.768478
1.336015
43
53.973713
52.882977
141.822510
1.273595
44
52.898853
53.404026
131.541824
1.294747
45
53.102123
52.277805
141.648468
1.245796
46
53.246372
52.223602
129.714691
1.274027
47
52.873920
53.640949
146.999481
1.241097
48
52.327316
51.982662
135.208038
1.196274
49
51.993839
51.716045
143.790405
1.167353
50
50.055206
49.669273
133.191376
1.143145
51
48.487289
48.507362
124.528709
1.138957
52
47.273788
50.232986
136.484650
1.126437
53
45.483864
45.788563
107.178497
1.091820
54
44.432549
45.138218
109.491188
1.047012
55
43.645912
44.297531
98.094994
1.028188
56
44.058010
44.367390
100.613739
1.031231
57
43.517994
44.877281
111.165077
0.984178
58
42.035976
43.166622
99.819710
0.931742
59
41.363792
47.001892
116.207237
0.945334
60
41.205067
40.212219
76.860603
0.928649
61
40.342945
40.350357
88.598007
0.842591
62
39.861126
43.915749
96.860260
0.951487
63
41.569279
49.137905
138.928375
1.038956
64
38.656174
38.269722
68.168587
0.821211
65
37.006947
38.081745
69.554688
0.809022
66
37.259399
42.006481
95.898788
0.919283
67
35.104584
33.845257
55.630444
0.751297
68
39.174126
45.108814
110.103287
0.879789
69
35.597847
40.795883
87.066353
0.849809
70
35.188492
45.994675
116.838234
0.895428
71
35.482456
41.519245
97.868591
0.838134
72
33.699818
41.204712
89.737938
0.779851
73
33.783867
40.763058
90.355843
0.799976
74
32.244606
36.947170
68.910339
0.784361
75
32.478546
39.241833
90.998299
0.794771
76
31.815594
37.286182
73.203354
0.767091
77
33.981007
46.702396
137.475922
1.037940
78
31.130989
39.411957
83.975334
0.777157
79
30.922899
40.445171
83.067902
0.791278
80
31.417219
38.267086
76.217957
0.783134
81
30.416584
37.150749
67.254387
0.795933
82
30.664265
42.865334
111.566460
0.906735
83
30.287697
37.563713
80.684685
0.763742
84
29.169527
37.883106
83.988747
0.785183
85
29.190128
35.485405
70.298973
0.761887
86
29.431238
37.160667
76.099060
0.770274
87
29.209200
36.153896
72.404541
0.765580
88
29.188496
34.842072
64.552872
0.753097
89
28.581736
34.654217
63.075729
0.726597
90
28.907957
38.962070
82.353249
0.789698
91
28.645546
36.130527
70.361160
0.778512
92
29.307817
40.542877
96.766304
0.826857
93
28.179117
37.192131
80.648216
0.781014
94
27.754902
35.200527
68.819534
0.747284
95
27.376822
41.386070
90.189461
0.855115
96
27.022631
36.052956
69.341690
0.773236
97
26.875269
35.434193
68.635651
0.799628
98
26.828203
35.232845
63.945633
0.729966
99
26.919025
34.566429
64.825203
0.774730
100
29.198364
41.009350
106.314301
0.906549
101
26.292440
30.752172
55.480087
0.736762
102
26.095945
35.301109
73.481239
0.757626
103
25.602512
34.754528
67.606255
0.814907
104
25.496876
34.207527
71.713341
0.757177
105
25.737532
44.820004
133.914154
0.959571
106
26.491989
38.033924
79.535103
0.842651
107
25.981333
39.242008
105.115875
0.862244
108
25.157051
32.971642
63.425369
0.784090
109
24.558952
33.757038
60.602299
0.739804
110
24.820776
33.399345
63.863464
0.750734
111
24.540861
30.890497
56.400993
0.727919
112
24.783148
39.267643
85.883987
0.848621
113
24.548325
33.016968
68.912323
0.764989
114
23.981777
36.378086
85.981781
0.826478
115
23.759510
33.783676
63.932438
0.753709
116
24.417717
31.576559
55.537151
0.719969
117
23.784721
34.481113
72.623199
0.757415
118
24.032053
35.183781
68.994904
0.786393
119
23.268295
33.948925
73.139702
0.741098
120
23.819862
33.930164
73.061981
0.767480
121
23.345638
35.781086
79.135872
0.832450
122
23.038601
36.972401
83.576294
0.910006
123
22.931234
30.959219
53.397114
0.746959
124
22.582979
37.205772
88.250519
0.823620
125
22.723442
34.164982
73.145271
0.765667
126
22.312342
33.039394
61.369057
0.767549
127
22.131725
33.849812
68.069473
0.770734
128
22.783657
43.646027
135.995071
1.051347
129
22.267717
31.346838
53.809063
0.745337
130
22.263374
34.920830
73.171074
0.788506
131
22.826796
37.179886
91.977150
0.830106
132
21.739532
32.454044
64.530495
0.771560
133
21.694023
35.031502
69.640236
0.925748
134
21.653934
31.773487
58.949425
0.785699
135
21.722937
36.371319
75.647873
0.814283
136
21.473295
36.882618
82.766670
0.837424
137
21.381804
32.407341
63.721237
0.765168
138
21.037113
32.611954
63.101082
0.765846
139
21.464874
33.156094
59.370163
0.723302
140
20.891520
37.794147
112.588371
0.894038
141
20.496185
32.480045
59.794239
0.694864
142
20.548031
37.449398
84.702400
0.822456
143
20.210930
38.629181
100.650192
0.850540
144
20.725042
37.834850
85.490715
0.858039
145
20.189522
35.133160
85.189323
0.813445
146
20.153017
39.193287
96.159355
0.879081
147
19.847898
29.531868
56.426586
0.677601
148
19.761581
42.129154
120.716011
1.034466
149
19.672522
32.375713
64.398384
0.713867
150
20.593376
34.283028
62.555462
0.777457
151
19.745571
36.027874
78.318108
0.776213
152
19.812349
33.375740
61.304756
0.787887
153
19.362700
32.409920
60.848419
0.735114
154
19.570093
34.876751
71.047813
0.810617
155
19.425051
33.068626
66.291054
0.742795
156
19.045027
37.307568
79.043678
0.866234
157
18.865076
37.361233
81.721832
0.833951
158
19.093390
33.356152
67.107307
0.788253
159
19.307087
30.742373
62.755524
0.688908
160
18.984171
35.436932
73.017296
0.790231
161
18.632318
40.033539
101.983307
0.811729
162
18.878891
36.239338
98.618774
0.979079
163
18.403385
35.462189
88.138664
0.795846
164
18.106642
35.180569
74.765259
0.739391
165
18.505413
33.306362
65.094505
0.745173
166
19.002251
40.978645
100.525909
0.931971
167
18.192255
33.429272
72.455231
0.727496
168
18.261755
37.701862
95.572548
0.831773
169
18.915636
42.487732
117.156212
1.156578
170
18.801580
37.112053
86.959740
0.919734
171
18.571062
35.175114
77.328072
0.767351
172
18.298080
36.872448
82.654427
0.837591
173
17.755650
34.374783
56.978531
0.736356
174
17.878935
38.829765
90.509659
0.848518
175
18.045612
35.368958
69.763222
0.737493
176
17.755033
31.598597
61.123852
0.711139
177
17.811846
32.907314
69.751549
0.822032
178
18.185652
36.846367
77.251816
0.809593
179
17.717474
37.120842
86.447746
0.866205
180
17.403942
37.230167
74.223022
0.807901
181
17.841801
39.048515
79.994209
0.913625
182
17.542055
41.842113
96.250389
0.907040
183
18.061693
38.794624
99.635231
0.921103
184
17.664684
30.544815
52.698326
0.684932
185
17.466208
33.475067
66.160286
0.794323
186
17.048010
35.554726
72.023987
0.797254
187
17.120783
37.959545
77.671486
1.002174
188
17.363098
39.000778
93.504486
0.907761
189
17.242704
36.641697
70.219223
0.772733
190
17.332376
40.995010
97.935852
0.902258
191
16.605289
37.417267
70.279465
0.707466
192
17.778837
37.491486
86.326897
0.961626
193
17.147409
31.587303
60.246082
0.724864
194
17.209435
38.813282
81.308617
0.862118
195
16.874672
40.000622
84.400436
0.823044
196
16.590700
33.899055
61.440830
0.699945
197
16.207685
32.210293
60.616905
0.743438
198
16.513975
33.513378
54.242081
0.674870
199
16.457279
33.936489
67.075310
0.771241
200
16.607021
33.235649
53.912506
0.725068
201
16.226772
30.777582
54.733650
0.671081
202
16.295120
34.147949
57.875702
0.719673
203
16.576664
40.671227
93.130890
0.851940
204
16.225008
29.280163
52.084255
0.682600
205
17.322882
35.688496
79.430199
0.859963
206
16.229267
30.362782
53.395069
0.675089
207
16.223068
31.804893
63.919720
0.765062
208
16.109938
31.523045
62.788395
0.738076
209
15.795094
33.030106
60.178982
0.686414
210
16.262978
36.841732
70.966347
0.751130
211
15.921555
40.079033
87.133110
0.813513
212
15.913036
36.886642
67.965874
0.834556
213
16.228205
35.550343
64.047379
0.701782
214
16.298664
31.984425
62.919086
0.747795
215
16.066477
36.681026
68.726448
0.868704
216
15.699594
35.238907
59.792511
0.801825
217
15.920628
36.859989
69.129364
0.754251
218
16.513695
36.387360
66.647827
0.831069
219
16.284431
39.403873
93.459244
0.933154
220
15.924437
44.703228
107.637184
0.979679
221
15.843354
39.728413
81.514198
0.846963
222
15.688183
40.093975
73.018761
0.847344
223
15.668453
38.013943
67.111938
0.822748
224
15.494933
32.541557
64.056778
0.785795
225
15.463197
34.010338
60.449951
0.738161
226
15.400600
38.782284
90.250053
0.843754
227
15.586081
33.701286
59.488392
0.837940
228
15.433411
36.798489
73.959267
0.909736
229
15.491731
39.602070
71.113853
1.013086
230
15.391937
37.482006
75.331535
0.776898
231
15.435884
38.652855
89.074577
0.836603
232
15.626000
35.137573
64.335411
0.752011
233
15.306082
39.634670
72.660141
0.917822
234
15.355454
32.865147
54.710400
0.789049
235
15.537284
37.812805
76.879021
0.940041
236
15.382627
35.241764
63.682838
0.807723
237
15.830091
40.551067
92.595360
0.974112
238
15.334457
34.133713
69.850311
0.747547
239
15.226704
39.387295
86.528572
0.826174
240
15.211497
35.794567
60.595177
0.834443
241
14.831562
36.866543
66.555702
0.865173
242
15.394850
35.404327
72.592766
0.803318
243
15.180771
32.891193
46.009857
0.764933
244
14.832344
37.297852
64.791862
0.774455
245
14.886884
37.178677
77.568916
0.945506
246
15.114301
37.307392
80.158997
0.812299
247
15.128016
35.778633
84.826706
0.897786
248
15.161243
38.432114
69.448364
0.790683
249
14.738536
33.209286
53.236851
0.753746
250
14.930013
38.159508
83.439903
0.841344
251
14.884100
30.120174
49.279636
0.697899
252
14.682936
38.194454
96.485626
0.820101
253
14.780913
35.699883
72.489815
0.830659
254
14.854465
36.435501
69.882416
0.809575
255
14.783436
40.635185
85.237053
0.913232
256
14.728690
35.541378
58.406082
0.758300
257
14.614870
37.259666
58.624901
0.786327
258
14.479601
34.260586
55.653889
0.752553
259
14.707333
37.941208
70.464417
0.931919
260
15.263060
33.869690
63.190929
0.776068
261
14.630364
36.586292
76.607384
0.854821
262
14.916113
33.110016
59.844917
0.945104
263
14.362852
36.640053
66.169907
0.854198
264
14.512224
35.944500
62.924007
0.815120
265
14.518485
35.471848
68.819427
0.837024
266
14.372441
35.790531
60.089851
0.850061
267
14.297349
42.501862
85.678368
1.031195
268
14.232774
38.601162
70.710434
0.954227
269
14.197763
34.126308
56.835575
0.731181
270
14.483047
47.907314
125.127914
1.218717
271
14.358400
40.992931
89.989151
0.912991
272
14.276005
37.163986
66.525978
0.792247
273
14.210490
36.434368
63.272625
0.850080
274
14.159444
44.882515
91.975616
1.067337
275
14.065702
39.301601
73.856964
0.855303
276
14.217259
38.020649
67.462250
0.911972
277
14.303395
38.757404
78.539825
0.899052
278
14.090721
35.979431
63.161999
0.859901
279
14.020180
38.949642
73.715523
0.900332
280
14.180801
40.710384
77.671898
1.073529
281
14.162584
37.344551
66.635529
0.968039
282
14.120151
36.017788
61.244427
0.808834
283
14.066149
41.633064
96.979980
0.917238
284
13.708919
38.084175
76.774963
0.782097
285
13.701283
34.544674
67.740196
0.740298
286
14.121503
40.151852
78.157623
0.851001
287
13.746382
41.040176
79.097305
1.008487
288
14.025649
34.821995
60.317513
0.742577
289
13.799110
36.059208
61.488655
0.848807
290
13.765931
38.428928
79.509003
0.914153
291
13.690766
38.088482
66.529884
1.011872
292
13.838598
36.949024
66.586395
0.840802
293
13.717721
35.707787
59.181881
0.801271
294
13.687765
33.001259
44.679905
0.788051
295
13.792114
40.192196
74.919266
0.847585
296
13.686102
41.424580
73.893829
0.957933
297
13.661887
35.499836
59.638813
0.749577
298
13.750915
35.873219
59.855495
0.868022
299
14.228426
34.791737
56.038673
0.745508
300
13.595055
33.988583
53.038189
0.772725
301
13.597076
36.386162
60.771076
0.810506
302
13.471545
37.539658
60.308643
0.823721
303
13.574639
39.752590
71.828415
1.088241
304
13.920475
41.222118
94.977074
0.984618
305
13.732122
36.738060
57.242893
0.841365
306
13.862490
37.092247
77.337212
0.967600
307
13.850560
33.499344
52.776295
0.769674
308
13.624640
37.554020
65.077057
0.810833
309
13.426353
43.089401
73.330788
0.992391
310
13.450548
37.321701
60.932724
0.748675
311
13.483907
38.794670
84.988472
0.891129
312
13.798834
35.658306
63.797874
0.823758
313
13.433642
32.610828
50.846367
0.731271
314
13.343544
32.993988
48.117706
0.767232
315
13.474988
36.691113
63.691349
0.898449
316
13.429296
38.860413
83.664520
1.031983
317
13.361508
35.176506
56.815937
0.826919
318
13.278620
39.751789
69.268814
1.064779
319
13.307871
35.232132
57.942730
0.831965
320
13.276241
36.155300
58.058113
0.821056
321
13.240404
33.182152
54.350239
0.806310
322
13.175077
32.868008
54.302299
0.803044
323
13.213413
32.984749
61.462555
0.823345
324
13.371449
37.937538
64.657791
0.875779
325
13.046414
36.068367
55.396469
0.806581
326
13.225363
36.366241
61.848137
0.878051
327
13.157030
35.214005
59.394054
0.781056
328
13.134717
38.963882
81.157677
0.904028
329
13.073193
41.658489
75.447479
0.969161
330
13.090325
33.283485
55.087288
0.738971
331
13.296698
36.199852
52.306442
0.823753
332
12.995165
36.111629
54.676262
0.848354
333
13.002842
37.476299
73.416077
1.084120
334
13.161036
38.306271
63.197983
1.130297
335
13.103834
33.749607
58.727905
0.879297
336
13.002410
34.198273
53.945587
0.780956
337
12.858998
38.131855
61.478062
0.906925
338
13.144301
41.692314
94.228462
1.098698
339
12.994774
36.444580
64.170647
0.848261
340
12.978182
33.547558
51.029385
0.807135
341
12.942981
34.245819
56.936619
0.794053
342
13.055573
35.550556
52.780064
0.818262
343
13.013928
38.767193
65.546707
0.864412
344
12.866686
34.855820
52.471138
0.797303
345
12.878273
36.959450
57.847755
0.868023
346
12.816094
36.925980
69.540977
0.864080
347
12.851155
39.476318
62.689823
0.911457
348
12.788852
34.329052
59.492905
0.849512
349
12.791994
37.979401
71.137390
1.068153
350
12.813991
37.190056
64.070984
0.892292
351
12.796708
37.938084
55.218567
0.886561
352
12.867328
42.055801
66.523285
1.083743
353
12.742359
35.659542
55.456100
0.864327
354
12.759057
38.184986
62.784393
0.922698
355
12.609703
40.404392
80.213448
0.853581
356
12.634410
36.687870
54.765980
0.836707
357
12.669363
33.221905
49.995766
0.781338
358
12.808789
37.109089
60.381718
1.026994
359
12.849763
34.593662
51.261200
0.813038
360
12.704800
40.702755
76.830826
1.118681
361
12.774818
32.621395
45.870682
0.705057
362
12.820094
37.081005
61.898838
0.785278
363
12.538405
33.900280
59.709343
0.812203
364
12.795860
37.484512
65.842865
0.897190
365
12.756648
33.978382
58.504395
0.972324
366
12.580325
35.424702
71.121269
0.976936
367
12.651705
32.676121
49.024632
0.753516
368
12.580230
34.890976
58.922325
0.809124
369
12.632515
34.719036
58.124001
0.830313
370
12.639567
35.679478
52.310707
0.768535
371
12.575728
35.321842
58.099461
0.799447
372
12.619104
37.118923
61.161499
0.900874
373
12.565275
36.026432
57.729786
0.820253
374
12.490048
35.925011
59.125061
0.864539
375
12.570393
33.363773
57.605824
0.787726
376
12.449144
32.254482
48.738400
0.757104
377
12.311611
32.634571
45.483692
0.801913
378
12.426764
33.488895
53.547668
0.757552
379
12.469316
37.325195
66.363983
0.880980
380
12.719226
36.013229
55.676968
0.851353
381
12.528697
35.423668
53.485611
0.859685
382
12.501586
36.498070
63.269669
0.812368
383
12.361664
34.617161
53.280651
0.822256
384
12.323334
34.713444
54.018089
0.801539
385
12.269219
33.173176
48.485111
0.790735
386
12.250194
35.126713
57.041138
0.797915
387
12.331409
35.184475
57.862411
0.855689
388
12.298031
33.965019
54.458473
0.797578
389
12.303589
33.125614
48.859356
0.946284
390
12.239066
35.905819
56.073475
0.822793
391
12.467511
32.162380
43.686119
0.697829
392
12.286215
36.352661
61.190887
0.834126
393
12.211934
37.417011
67.354195
0.825209
394
12.502189
34.833328
56.821518
0.839059
395
12.167605
31.657478
44.316624
0.720613
396
12.202771
34.178574
56.687813
0.785390
397
12.206116
34.017262
47.890656
0.779165
398
12.386162
36.233345
57.043720
0.833403
399
12.288252
33.675564
52.371277
0.773573
400
12.199110
34.479584
54.236267
0.765616
401
12.106255
33.104267
41.109093
0.753728
402
12.192243
37.162727
60.973217
0.877134
403
12.166115
36.112846
62.293644
0.876282
404
12.051862
34.599560
56.568539
0.831527
405
12.152299
35.801731
56.881062
0.939389
406
12.071044
34.257393
55.462967
0.768282
407
12.301889
33.355858
46.524151
0.811592
408
11.988328
35.438625
55.166687
0.862791
409
12.168221
36.728107
62.928780
0.915328
410
12.041191
34.772251
61.766026
0.891146
411
12.104047
32.878361
56.299400
0.769873
412
12.152422
33.359688
51.822353
0.735358
413
12.132710
34.363724
46.458733
0.788493
414
12.005080
31.935381
50.359497
0.745705
415
12.007692
31.619043
54.513432
0.744616
416
11.938269
32.662510
55.223244
0.770022
417
11.943339
34.436939
53.123196
0.820285
418
12.150605
35.462070
56.342068
0.756428
419
12.002385
32.409164
47.016636
0.768401
420
12.045143
34.849861
58.498520
0.848078
421
11.965784
34.549320
50.797157
0.833625
422
11.943920
33.853745
56.128063
0.853530
423
11.880611
34.870552
54.711914
0.784233
424
11.995658
32.782089
48.833664
0.821170
425
11.914958
31.920240
51.242035
0.776090
426
11.805329
30.355015
42.911594
0.711182
427
11.935860
33.408634
50.508858
0.749738
428
11.772082
34.430420
50.202419
0.809431
429
11.736680
33.981133
54.370186
0.755592
430
11.794690
32.499371
51.128971
0.743649
431
11.847500
31.108654
44.055923
0.693529
432
11.816026
35.574501
53.410114
0.848332
433
11.903443
34.996403
53.236588
0.845075
434
11.948842
32.947636
54.303482
0.711090
435
12.243643
36.320072
58.118374
0.933348
436
11.913749
34.506680
54.372795
0.860175
437
11.837273
38.140785
63.231224
0.895437
438
11.779114
34.939205
54.261112
0.723276
439
11.615091
33.369667
55.079437
0.756075
440
11.931456
34.537476
60.722855
0.948659
441
11.787133
35.227863
57.149357
1.033259
442
11.780608
31.142134
47.562675
0.805610
443
11.636442
32.187580
48.354527
0.770644
444
11.637978
32.323700
46.854424
0.702546
445
11.903340
35.190704
50.646305
0.852460
446
11.571512
33.237946
56.689281
0.777469
447
11.599289
33.937153
52.734787
0.755253
448
11.579601
32.428337
47.396729
0.811529
449
11.744120
32.505245
51.765270
0.840782
450
11.666202
34.824070
57.006863
0.816639
451
11.648829
34.638466
59.238213
0.882056
452
11.793554
30.710043
45.856174
0.694261
453
11.659192
33.148499
54.069881
0.805099
454
11.637027
33.141830
55.173813
0.788581
455
11.636176
29.656418
40.620483
0.758811
456
11.555015
31.963772
46.010929
0.753087
457
11.522563
32.628345
40.730591
0.791435
458
11.672043
35.127548
55.553432
0.870043
459
11.552064
34.028255
50.035458
0.825540
460
11.573618
32.549507
47.255211
0.773338
461
11.536166
37.656078
67.344048
0.983174
462
11.485510
34.905525
53.082489
0.876960
463
12.100968
32.509598
45.874111
0.738427
464
11.645988
35.274570
53.998070
0.845883
465
11.577806
35.863506
53.654526
0.889842
466
11.482991
36.541199
54.633507
0.946504
467
11.419836
35.270477
57.667664
0.860766
468
11.445416
32.532867
45.412918
0.808560
469
11.609118
36.200794
53.069351
0.884522
470
11.508022
31.802618
45.102032
0.766375
471
11.555008
32.928139
47.937088
0.790140
472
11.390872
34.186924
49.535130
0.859464
473
11.388084
35.261086
57.839161
0.829830
474
11.583447
35.107101
51.975403
0.859695
475
11.494924
31.022924
49.652603
0.717981
476
11.468680
34.738056
52.113766
0.873243
477
11.391685
32.519497
52.387321
0.794116
478
11.339947
32.575859
46.261452
0.784512
479
11.333695
33.122242
46.442688
0.766425
480
11.354995
33.711990
52.789104
0.851568
481
11.302698
33.098431
45.819736
0.840131
482
11.348128
33.734245
47.610298
0.811764
483
11.292136
34.134396
47.622189
0.842849
484
11.374100
34.323856
50.191330
0.806330
485
11.306168
32.635502
44.956974
0.786807
486
11.278807
33.210289
54.042301
0.774301
487
11.253035
32.103420
49.920570
0.766805
488
11.377320
33.151814
45.958019
0.806843
489
11.287739
31.858513
46.511711
0.782634
490
11.293363
32.641758
48.649235
0.813386
491
11.299816
33.480637
49.293549
0.859578
492
11.321927
32.291782
51.565430
0.823631
493
11.156326
31.785048
43.623890
0.801076
494
11.235230
32.648998
55.805870
0.824433
495
11.259790
34.124096
49.528847
0.877523
496
11.215596
34.206402
45.078144
0.862263
497
11.323204
33.656700
56.458630
0.856608
498
11.231325
32.744850
51.818192
0.813148
499
11.142723
31.251940
47.446789
0.781808
500
11.108985
30.265013
44.391117
0.739530
501
11.166247
34.717888
53.372711
0.837495
502
11.165710
31.448730
45.793495
0.824465
503
11.121704
32.676964
50.460670
0.811395
504
11.204881
33.975521
54.055832
0.895764
505
11.081209
34.690880
55.574245
0.847157
506
11.057419
33.189896
48.075176
0.779893
507
11.131548
29.846952
39.400620
0.726093
508
11.080618
35.202263
51.126720
0.861672
509
11.100417
31.947847
49.098103
0.814798
510
11.080097
31.969969
43.697010
0.803608
511
11.109570
34.695049
51.994610
0.864256
512
11.168578
33.182987
45.564724
0.803472
513
11.033402
36.209171
55.012287
0.869068
514
11.052049
35.058483
55.960583
0.845634
515
10.939099
31.855934
49.691120
0.835719
516
11.056652
32.300129
45.363037
0.753454
517
11.091005
33.233646
50.874031
0.770722
518
10.956681
30.811356
42.861607
0.735481
519
10.981388
31.026829
44.894558
0.785294
520
11.004642
32.647758
46.569706
0.823831
521
10.953099
34.234478
46.780113
0.822051
522
10.973938
31.475555
43.187763
0.757077
523
11.007047
31.275043
42.730217
0.802831
524
10.918464
32.520325
47.954807
0.805515
525
10.886256
32.177811
40.442772
0.752994
526
10.892241
32.174744
51.115650
0.812298
527
10.813769
33.280319
48.492802
0.819348
528
10.909218
30.996675
47.737335
0.793023
529
10.918257
31.962088
45.065788
0.801840
530
10.900455
33.999947
45.241333
0.822240
531
10.880502
32.474720
46.153351
0.731042
532
10.864187
31.137716
43.594692
0.735325
533
10.884331
29.886957
42.479042
0.750811
534
11.131597
31.612997
43.821152
0.800037
535
10.910157
33.998192
52.329884
0.832512
536
10.876169
31.436926
46.153633
0.791093
537
10.821349
32.313305
43.080772
0.761663
538
10.691327
31.478374
44.661980
0.807908
539
10.733187
32.746445
46.056431
0.791561
540
10.881904
31.384829
42.909969
0.732672
541
10.812654
32.902493
47.176476
0.758470
542
10.676754
31.973196
41.508629
0.762530
543
10.729968
30.537022
42.985134
0.754099
544
10.736858
34.341480
48.234833
0.872449
545
10.712658
32.284100
43.297245
0.806100
546
10.783628
30.345379
41.466480
0.757943
547
10.738046
30.782459
40.055653
0.706487
548
10.677889
29.311604
38.942455
0.702269
549
10.657338
33.291080
49.869057
0.842457
550
10.889813
32.747063
45.450943
0.825145
551
10.745845
31.188248
42.008823
0.806400
552
10.660393
33.461025
45.837399
0.882209
553
10.612618
32.400742
44.933666
0.829561
554
10.671711
30.866146
41.817032
0.788256
555
10.718171
33.100689
43.650986
0.816862
556
10.687399
32.339050
47.633419
0.830143
557
10.650363
33.068401
45.152863
0.796859
558
10.658573
31.572496
46.176575
0.850385
559
10.670061
31.468437
43.738743
0.798255
560
10.579093
31.051060
41.064877
0.722195
561
10.547851
32.885674
41.851555
0.786262
562
10.569036
30.776751
42.416008
0.807351
563
10.614475
31.781672
41.541836
0.755998
564
10.582497
32.539455
45.294250
0.834737
565
10.616549
32.310062
45.778389
0.810500
566
10.826120
32.827736
46.637016
0.809630
567
10.542817
35.125851
52.375713
0.841108
568
10.548382
30.902235
43.887844
0.837888
569
10.540369
33.409454
46.572956
0.874809
570
10.436017
31.887850
43.440018
0.800166
571
10.564009
32.483303
50.893322
0.829127
572
10.586737
32.004173
49.518562
0.829193
573
10.549219
31.305660
42.238922
0.788435
574
10.540456
30.608231
42.825375
0.806174
575
10.559407
32.871227
44.891594
0.841798
576
10.577321
31.312138
44.013927
0.760933
577
10.525023
31.896362
44.472618
0.831106
578
10.457129
32.193584
43.791088
0.811496
579
10.528196
30.082069
41.964920
0.781214
580
10.493624
32.932198
40.529228
0.806866
581
10.503120
30.846781
44.609730
0.764951
582
10.519255
30.212406
40.270134
0.705149
583
10.511444
31.407768
44.606392
0.784281
584
10.484232
32.019417
42.804447
0.771105
585
10.473190
31.936001
43.349918
0.800597
586
10.459234
31.801416
41.099403
0.757221
587
10.373753
31.679941
43.897717
0.822217
588
10.370695
32.181934
42.148083
0.758530
589
10.486632
32.252773
45.361038
0.820570
590
10.415092
31.944538
43.218761
0.798378
591
10.418789
32.272865
43.792267
0.816576
592
10.376585
31.896963
41.935593
0.767692
593
10.412916
32.306396
44.512131
0.802180
594
10.337895
32.150246
39.602020
0.750949
595
10.334853
32.682358
43.564617
0.811374
596
10.395268
31.293781
44.229626
0.742424
597
10.328351
30.480072
42.357971
0.761181
598
10.301076
32.719143
45.066521
0.818073
599
10.393769
31.476978
41.709400
0.772905
600
10.309531
30.193087
41.269161
0.771917
601
10.326076
29.917000
44.026421
0.743222
602
10.300266
30.687998
42.873367
0.753437
603
10.182746
32.722397
45.707802
0.753374
604
10.308241
33.241085
43.662582
0.790943
605
10.338623
31.647640
42.061272
0.778817
606
10.302413
31.873144
42.941525
0.794929
607
10.184274
32.177017
41.767799
0.781455
608
10.272506
30.825581
40.719013
0.747567
609
10.248549
32.106339
43.710194
0.780104
610
10.240237
32.319740
44.257793
0.784601
611
10.282104
32.756092
45.696377
0.768920
612
10.390173
31.231035
42.622723
0.791275
613
10.321748
28.594395
39.625011
0.752189
614
10.213476
30.117050
39.224636
0.757631
615
10.147255
31.239347
41.116348
0.783254
616
10.225457
32.293961
42.130096
0.767570
617
10.237810
32.022163
42.237946
0.788319
618
10.241763
31.714756
45.902386
0.777540
619
10.211644
31.169041
44.709095
0.785831
620
10.141590
31.159685
36.629150
0.769987
621
10.203721
31.170979
43.682644
0.779005
622
10.132215
30.309435
39.699505
0.758349
623
10.191387
30.477232
41.082172
0.761230
624
10.202635
31.717291
42.321247
0.782434
625
10.176472
30.269129
41.432739
0.750885
626
10.062234
30.373434
40.190861
0.745283
627
10.148728
30.880890
47.748814
0.758331
628
10.161660
30.149088
40.276867
0.766280
629
10.076932
32.251324
45.949512
0.809276
630
10.014497
30.223806
38.812763
0.753666
631
10.147480
31.401785
41.958405
0.783280
632
10.089665
31.324041
41.294807
0.762392
633
10.146390
32.783264
44.675529
0.834137
634
10.010672
28.839329
39.032829
0.764788
635
10.090048
31.592875
40.087925
0.755136
636
10.167812
31.371788
40.773643
0.764319
637
10.071571
32.520931
43.621811
0.809180
638
10.084064
30.704103
46.330196
0.764378
639
10.082823
31.124193
41.827141
0.789817
640
10.115430
31.676672
40.732712
0.782872
641
10.041951
32.470097
41.621101
0.796224
642
10.047478
30.169109
41.879143
0.768825
643
10.086026
31.368666
41.162083
0.784595
644
10.021622
32.116001
44.177227
0.813651
645
9.993053
30.397690
40.735558
0.783744
646
10.015490
30.631868
40.086662
0.757853
647
10.008081
31.987120
48.596672
0.803098
648
9.974685
31.984331
48.100101
0.816306
649
10.017381
31.248077
40.230179
0.766258
650
10.020640
32.000061
47.333588
0.814199
651
9.997214
29.723112
43.755695
0.795468
652
9.981441
31.737379
41.607422
0.756347
653
9.971068
31.775999
41.492264
0.762575
654
9.992212
30.865385
40.317120
0.790785
655
9.931530
31.564976
41.215267
0.772238
656
9.935065
31.853416
41.622173
0.782915
657
10.056235
32.223610
45.113853
0.816813
658
10.034753
31.116976
41.442284
0.765694
659
9.956408
31.713394
44.925644
0.821500
660
9.956842
31.979059
44.247322
0.834422
661
9.961246
31.715359
43.274761
0.787393
662
9.910225
30.676476
39.511555
0.760025
663
9.920900
30.016319
39.390263
0.773459
664
9.911504
31.230188
42.298832
0.793608
665
9.945103
31.678097
44.985657
0.797827
666
9.942798
31.713078
40.818367
0.762015
667
9.907918
31.293547
42.838490
0.810552
668
9.922390
31.780157
43.648235
0.806912
669
9.892386
30.905590
41.307827
0.750406
670
9.979078
31.842716
42.867973
0.791048
671
9.946197
31.588737
43.264408
0.777233
672
9.865866
31.432087
41.144989
0.775803
673
9.983168
31.914169
43.800617
0.805683
674
9.828781
30.372656
39.472126
0.762793
675
9.831721
32.561714
43.959061
0.804887
676
9.883856
30.674862
41.189682
0.789501
677
9.885988
30.876160
40.452320
0.779910
678
9.869435
32.730408
43.312080
0.777129
679
9.827736
30.116957
40.042072
0.755432
680
9.848602
31.300837
41.809254
0.790091
681
9.834074
31.698816
41.792267
0.784149
682
9.857887
30.991501
41.364262
0.769668
683
9.883680
31.140213
43.657055
0.789945
684
9.847812
30.932934
41.761513
0.770102
685
9.834020
31.624313
42.631992
0.786023
686
9.910093
31.940235
41.694534
0.772361
687
9.810919
30.049749
38.494488
0.741426
688
9.796085
29.700165
38.188263
0.750679
689
9.868774
30.717962
40.764320
0.776954
690
9.903551
33.229519
44.173901
0.802213
691
9.817249
31.294138
40.596779
0.782075
692
9.824275
29.425268
40.771614
0.779332
693
9.842527
29.503506
39.334763
0.769946
694
9.794269
29.239500
37.961208
0.754913
695
9.817785
30.945566
39.964619
0.772701
696
9.817707
29.854034
40.415501
0.780207
697
9.816250
32.428398
43.072540
0.795279
698
9.854343
31.537256
40.777443
0.770251
699
9.796930
31.127625
40.363930
0.769160
700
9.761352
31.748146
42.998409
0.794501
701
9.727106
30.759462
39.631405
0.781783
702
9.818300
31.566833
40.856277
0.764799
703
9.735809
30.461531
40.020725
0.769803
704
9.709221
32.073898
41.426407
0.774317
705
9.860830
30.718891
40.711575
0.768392
706
9.753155
31.812387
40.169922
0.774792
707
9.783849
31.443691
42.912380
0.802167
708
9.753920
29.241016
38.127773
0.752105
709
9.736170
31.600452
41.892467
0.771903
710
9.778872
31.106777
40.985779
0.762869
711
9.781320
32.161392
42.339863
0.758567
712
9.723782
31.131306
41.367077
0.770838
713
9.713279
29.560587
40.657448
0.785997
714
9.705586
29.352640
38.429882
0.743384
715
9.712332
30.916153
40.737099
0.764970
716
9.698907
31.785757
40.028885
0.784439
717
9.636447
31.517471
42.659718
0.802592
718
9.740340
30.049469
39.182392
0.754952
719
9.728582
30.617294
39.796364
0.769710
720
9.678654
30.737925
42.368881
0.803600
721
9.711122
30.299809
41.268520
0.793355
722
9.619991
31.085951
39.955555
0.782532
723
9.751316
29.779947
40.341827
0.783709
724
9.674048
31.290476
41.474079
0.778345
725
9.672349
30.179070
39.595337
0.771420
726
9.699500
29.463131
38.146900
0.762972
727
9.669084
30.620199
39.680897
0.758486
728
9.628733
29.497955
39.276901
0.781883
729
9.671335
29.602190
39.459969
0.771758
730
9.661795
30.792170
40.036633
0.762638
731
9.684377
30.465631
39.869690
0.766051
732
9.668206
31.793076
41.195175
0.770406
733
9.652674
30.564663
40.454071
0.756535
734
9.625160
32.018318
42.422970
0.780837
735
9.570590
30.961109
39.625145
0.772827
736
9.620033
30.734894
39.833469
0.774007
737
9.586766
30.208803
39.739986
0.768127
738
9.670497
30.356806
38.895210
0.774671
739
9.615580
31.292831
40.804581
0.753298
740
9.560591
30.176247
39.565033
0.767050
741
9.573107
30.809959
40.802887
0.775361
742
9.636915
29.830799
39.214188
0.756181
743
9.597157
30.645247
40.622963
0.765673
744
9.617984
30.461678
39.248043
0.771196
745
9.630928
30.437168
40.714092
0.771946
746
9.646561
30.845448
41.071667
0.788137
747
9.542443
30.773481
40.134399
0.769803
748
9.590741
30.436485
40.005524
0.761820
749
9.679118
29.119307
38.218044
0.758506
750
9.595251
30.149401
39.743889
0.785995
751
9.565324
29.042376
38.672707
0.766602
752
9.583814
31.691635
40.837067
0.788996
753
9.606961
30.748987
39.166336
0.774745
754
9.601628
31.929968
41.832760
0.789794
755
9.604128
29.943163
39.956245
0.778257
756
9.570592
31.073618
41.455475
0.784281
757
9.626746
32.435730
42.175480
0.786749
758
9.531698
31.319565
40.691662
0.772488
759
9.540717
30.583803
39.801903
0.769506
760
9.625671
30.024263
39.508583
0.778451
761
9.571132
30.511950
39.708309
0.766924
762
9.509952
30.981873
40.450661
0.775295
763
9.578967
30.622362
39.369583
0.770379
764
9.573689
30.510088
40.041912
0.762529
765
9.509113
30.627520
39.801033
0.782817
766
9.528758
30.228762
39.832619
0.777744
767
9.546545
31.797190
42.300098
0.764829
768
9.582674
31.294750
40.103764
0.772561
769
9.593339
30.542294
39.817646
0.742610
770
9.605639
30.074352
39.240555
0.767354
771
9.710602
30.982210
39.419342
0.773902
772
9.537987
31.435150
40.593403
0.786945
773
9.590886
31.206141
40.566963
0.767543
774
9.567395
31.749069
40.750977
0.773995
775
9.542539
29.619267
38.685436
0.763386
776
9.527358
30.614010
39.639404
0.766361
777
9.539520
31.429237
40.269295
0.768710
778
9.559960
30.711203
40.088882
0.775561
779
9.513242
30.790771
39.574284
0.782673
780
9.529019
30.343843
39.506577
0.774377
781
9.466898
30.756144
39.555008
0.775904
782
9.496929
29.230103
38.275623
0.754839
783
9.530871
30.259159
39.312073
0.762919
784
9.511191
31.637276
40.612881
0.772637
785
9.542443
31.045744
39.771599
0.775636
786
9.636745
29.079962
38.024403
0.754148
787
9.520823
29.941092
38.866543
0.766303
788
9.544065
31.280649
40.372971
0.768853
789
9.571733
31.422525
40.813103
0.764549
790
9.478342
30.363409
39.369511
0.768250
791
9.537759
29.589869
38.728600
0.770246
792
9.539391
30.926800
39.580605
0.772210
793
9.550218
31.530407
41.006973
0.771047
794
9.527091
30.411278
39.702164
0.764613
795
9.571325
31.526840
40.725555
0.774288
796
9.554912
30.172140
39.190933
0.775775
797
9.538812
31.068150
40.053707
0.787963
798
9.537267
30.900993
40.010624
0.773879
799
9.507723
30.391706
39.997898
0.767159
800
9.558806
32.454838
42.460445
0.792395
In [11]:
# Plot the per-batch training / validation loss curves recorded by the
# learner during fit (fastai Recorder).
l.recorder.plot_losses()
In [6]:
# Checkpoint the trained model under models/; the filename encodes the run's
# hyperparameters (optimizer, batch norm, loss, depth, log-scaling) so that
# different configurations can be kept side by side.
l.save(f"r_speedup_{optimizer}_batch_norm_{batch_norm}_{loss_func}_nlayers_{len(layers_sizes)}_log_{log}")
In [7]:
# List saved checkpoints to confirm the new .pth file was written.
!ls models
old_models
old_repr
r_speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False2.pth
speedup_Adam_batch_norm_True_MAPE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MSE_nlayers_5_log_False.pth
speedup_Adam_batch_norm_True_MSE_nlayers_5_log_True.pth
tmp.pth
In [12]:
# Collect per-sample model predictions vs. targets for the validation and
# training loaders into DataFrames for the error analysis below.
val_df = get_results_df(val_dl, l.model)
train_df = get_results_df(train_dl, l.model)
In [13]:
# Point the working frame at the training-set results first; `df` is
# re-pointed to the validation results in a later cell.
df = train_df
In [14]:
# Summary statistics of predictions vs. targets on the current results frame.
# The original `df[:]` was a redundant full-slice copy before the column
# selection (chained indexing); select the columns directly instead.
df[['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[14]:
prediction
target
abs_diff
APE
count
245283.000000
245283.000000
2.452830e+05
245283.000000
mean
1.102805
1.135724
2.135267e-01
22.577551
std
1.295271
1.405682
5.665563e-01
91.278419
min
0.010044
0.008491
2.048910e-07
0.000030
25%
0.280880
0.278690
5.817745e-03
1.646686
50%
0.855459
0.899071
3.095371e-02
5.807508
75%
1.050393
1.036481
1.352153e-01
17.479273
max
8.452207
16.089287
1.541526e+01
5824.463867
In [15]:
# Switch the working frame to the validation-set results for the
# per-transformation analyses below.
df = val_df
In [16]:
# Summary statistics of predictions vs. targets on the current results frame.
# The original `df[:]` was a redundant full-slice copy before the column
# selection (chained indexing); select the columns directly instead.
df[['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[16]:
prediction
target
abs_diff
APE
count
10000.000000
10000.000000
10000.000000
10000.000000
mean
1.176222
1.430897
0.442006
42.532043
std
1.231279
1.683147
0.734061
148.393097
min
0.014810
0.014795
0.000040
0.006628
25%
0.389919
0.395193
0.024651
4.690779
50%
0.961758
1.000000
0.134899
22.480320
75%
1.257520
1.621683
0.497272
48.398965
max
7.889497
10.872228
6.232839
5399.246094
In [35]:
# Export per-sample results to CSV for external inspection. `df[:]` was a
# redundant full-slice copy; select the columns directly.
# NOTE(review): the next cell writes a filtered subset to the SAME path,
# silently overwriting this file — use distinct filenames if both are needed.
df[['index', 'name', 'prediction', 'target', 'abs_diff', 'APE']].to_csv(path_or_buf='./eval_results.csv', sep=';')
In [36]:
# Restrict to the no-transformation baseline (no interchange, no unroll, no
# tile). The target speedup is 1.0 for every such row (see the describe()
# output), so these stats measure pure model bias on untransformed schedules.
# Compute the subset once instead of duplicating the boolean filter.
baseline = df[(df.interchange == 0) & (df.unroll == 0) & (df.tile == 0)]
baseline[['index', 'name', 'prediction', 'target', 'abs_diff', 'APE']].to_csv(path_or_buf='./eval_results.csv', sep=';')
baseline[['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[36]:
prediction
target
abs_diff
APE
count
2074.000000
2074.0
2074.000000
2074.000000
mean
0.998558
1.0
0.050076
5.007599
std
0.208153
0.0
0.202042
20.204212
min
0.603882
1.0
0.000066
0.006628
25%
0.979231
1.0
0.002551
0.255150
50%
0.995987
1.0
0.004715
0.471514
75%
0.997637
1.0
0.022805
2.280509
max
6.001675
1.0
5.001675
500.167511
In [19]:
# Error stats for schedules that apply only tiling (no interchange, no unroll).
tile_only = (df.interchange == 0) & (df.unroll == 0) & (df.tile == 1)
df.loc[tile_only, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[19]:
prediction
target
abs_diff
APE
count
725.000000
725.000000
725.000000
725.000000
mean
1.462764
1.519225
0.393448
29.481096
std
1.346193
1.422773
0.573493
31.370106
min
0.145803
0.091200
0.000040
0.009375
25%
0.494025
0.508859
0.070696
8.597638
50%
0.991690
0.993141
0.191803
19.992519
75%
2.151304
2.207209
0.496262
41.369717
max
7.768654
9.358214
3.835224
295.724640
In [21]:
# Error stats for schedules that apply only unrolling.
unroll_only = (df.interchange == 0) & (df.unroll == 1) & (df.tile == 0)
df.loc[unroll_only, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[21]:
prediction
target
abs_diff
APE
count
281.000000
281.000000
281.000000
281.000000
mean
4.372236
6.044158
1.848176
66.339111
std
2.171317
3.221366
1.231329
163.969162
min
0.904279
0.105179
0.001005
0.060693
25%
1.006489
2.546717
1.032713
18.170967
50%
5.899143
7.389543
1.563098
31.399256
75%
5.972240
8.668690
2.789691
43.133614
max
6.070222
10.872228
4.899988
848.387634
In [22]:
# Error stats for schedules that apply only loop interchange.
interchange_only = (df.interchange == 1) & (df.unroll == 0) & (df.tile == 0)
df.loc[interchange_only, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[22]:
prediction
target
abs_diff
APE
count
232.000000
232.000000
232.000000
232.000000
mean
0.913173
0.918692
0.347373
122.910507
std
1.018806
1.168210
0.588294
610.275146
min
0.084280
0.018092
0.001015
0.220132
25%
0.260828
0.265846
0.033345
7.204351
50%
0.456005
0.466655
0.101908
19.017271
75%
1.092822
0.956079
0.327444
41.838287
max
5.042413
8.069739
3.409136
5399.246094
In [23]:
# Error stats for schedules combining unrolling and tiling (no interchange).
unroll_and_tile = (df.interchange == 0) & (df.unroll == 1) & (df.tile == 1)
df.loc[unroll_and_tile, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[23]:
prediction
target
abs_diff
APE
count
868.000000
868.000000
868.000000
868.000000
mean
1.589693
1.748770
0.590289
37.538128
std
1.585954
1.605983
0.793788
50.443039
min
0.168853
0.057130
0.000085
0.043231
25%
0.442154
0.695116
0.105190
12.070397
50%
1.012416
1.241629
0.302328
29.301636
75%
2.136123
2.229650
0.773482
44.885977
max
7.889497
10.137201
4.684408
414.385712
In [24]:
# Error stats for schedules combining interchange and unrolling (no tiling).
interchange_and_unroll = (df.interchange == 1) & (df.unroll == 1) & (df.tile == 0)
df.loc[interchange_and_unroll, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[24]:
prediction
target
abs_diff
APE
count
1276.000000
1276.000000
1276.000000
1276.000000
mean
1.962418
2.878316
1.036611
72.382378
std
1.334146
2.108479
1.021500
274.414642
min
0.101093
0.042484
0.000331
0.041299
25%
0.819918
0.977715
0.154067
10.751699
50%
2.072164
2.629361
0.717993
40.390493
75%
2.634203
4.365878
1.818399
50.273042
max
4.919624
9.784180
5.172597
2245.031982
In [25]:
# Error stats for schedules combining interchange and tiling (no unrolling).
interchange_and_tile = (df.interchange == 1) & (df.unroll == 0) & (df.tile == 1)
df.loc[interchange_and_tile, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[25]:
prediction
target
abs_diff
APE
count
1663.000000
1663.000000
1663.000000
1663.000000
mean
0.827254
0.937209
0.295498
32.696384
std
0.851419
1.023356
0.457449
32.905022
min
0.014810
0.014836
0.000083
0.030472
25%
0.217906
0.246893
0.038760
9.840522
50%
0.529727
0.624512
0.132425
24.348541
75%
1.120389
1.173670
0.354944
48.979925
max
4.983345
8.724252
5.141850
309.855896
In [26]:
# Error stats for schedules applying all three transformations.
all_three = (df.interchange == 1) & (df.unroll == 1) & (df.tile == 1)
df.loc[all_three, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[26]:
prediction
target
abs_diff
APE
count
2881.000000
2881.000000
2881.000000
2881.000000
mean
0.670118
0.858291
0.383388
57.996346
std
0.757664
1.094979
0.572025
79.790131
min
0.016312
0.014795
0.000374
0.131162
25%
0.158396
0.183774
0.059443
19.831461
50%
0.420433
0.424563
0.154730
43.728092
75%
0.847958
1.033311
0.456153
58.873505
max
4.961184
9.207530
6.232839
1148.485840
In [27]:
# Error stats over all schedules that apply at least one transformation.
any_transform = (df.interchange + df.tile + df.unroll) != 0
df.loc[any_transform, ['prediction', 'target', 'abs_diff', 'APE']].describe()
Out[27]:
prediction
target
abs_diff
APE
count
7926.000000
7926.000000
7926.000000
7926.000000
mean
1.222709
1.543651
0.544563
52.351154
std
1.375149
1.874331
0.786424
164.959198
min
0.014810
0.014795
0.000040
0.009375
25%
0.304365
0.312547
0.065765
12.863225
50%
0.714742
0.830516
0.205677
34.351240
75%
1.788819
2.132191
0.660448
52.750795
max
7.889497
10.872228
6.232839
5399.246094
In [28]:
# Prediction-vs-target joint plots, one per (interchange, unroll, tile)
# combination, followed by "any transformation" and the full dataset.
# The original copy-pasted cells plotted the (0, 1, 1) combination TWICE
# and repeated eight near-identical statements; a single loop over the
# combination list removes the duplicate and the repetition while keeping
# the original plotting order.
plot_title = f"Validation dataset, {loss_func} loss"
combinations = [
    (0, 0, 0),  # baseline: no transformation
    (0, 0, 1),  # tile only
    (0, 1, 0),  # unroll only
    (1, 0, 0),  # interchange only
    (0, 1, 1),  # unroll + tile
    (1, 1, 0),  # interchange + unroll
    (1, 0, 1),  # interchange + tile
    (1, 1, 1),  # all three
]
for interchange, unroll, tile in combinations:
    subset = df[(df.interchange == interchange)
                & (df.unroll == unroll)
                & (df.tile == tile)]
    joint_plot(subset, plot_title)
# Schedules with at least one transformation, then the entire dataset.
joint_plot(df[(df.interchange + df.tile + df.unroll != 0)], plot_title)
joint_plot(df, plot_title)
/data/scratch/henni-mohammed/anaconda3/lib/python3.7/site-packages/scipy/stats/stats.py:1713: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval
In [ ]:
In [ ]:
Content source: rbaghdadi/COLi
Similar notebooks: