In [3]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [40, 30]

In [4]:
import fastai as fai
from fastai.basic_data import DataLoader
from data_loader import *
import numpy as np
from torch.utils.data import SubsetRandomSampler
from model import *
from model_bn import *
from torch import nn, optim
import dill
import papermill as pm
from main import *
import pandas as pd

In [5]:
# Training configuration
optimizer = 'Adam'
num_workers = 8
maxsize = 100000
batch_size = 2048
n_epochs = 500
batch_norm = True
dataset = 'data/speedup_dataset.h5'

In [6]:
train_dl, val_dl = train_dev_split(dataset, batch_size, num_workers, maxsize)

db = fai.basic_data.DataBunch(train_dl, val_dl)
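
train_dev_split comes from the local data_loader module, so its exact behavior is not shown in this notebook. A minimal sketch of what it presumably does, assuming a dataset class that loads the HDF5 file and exposes .X and .Y tensors (the SpeedupDataset name, the 90/10 split, and the dev_frac default are all hypothetical):

def train_dev_split(dataset_path, batch_size, num_workers, maxsize, dev_frac=0.1):
    # SpeedupDataset is a hypothetical stand-in for whatever data_loader
    # actually defines; it must expose .X and .Y (see the shapes used below).
    ds = SpeedupDataset(dataset_path, maxsize=maxsize)
    # Shuffle indices once, then carve off a held-out dev set.
    indices = np.random.permutation(len(ds))
    split = int(len(ds) * (1 - dev_frac))
    train_dl = DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                          sampler=SubsetRandomSampler(indices[:split]))
    dev_dl = DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                        sampler=SubsetRandomSampler(indices[split:]))
    return train_dl, dev_dl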

In [7]:
input_size = train_dl.dataset.X.shape[1]
output_size = train_dl.dataset.Y.shape[1]


# Select the architecture: Model_BN adds batch normalization layers.
if batch_norm:
    model = Model_BN(input_size, output_size)
else:
    model = Model(input_size, output_size)
    
criterion = nn.MSELoss()

l = fai.Learner(db, model, loss_func=criterion)  # fastai v1 defaults to Adam

# Only override the optimizer when SGD is explicitly requested.
if optimizer == 'SGD':
    l.opt_func = optim.SGD

In [12]:
l.lr_find()


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

In [13]:
l.recorder.plot()

[plot: loss vs. learning rate]
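
The plot shows loss as a function of learning rate on a log scale; the value chosen below (1e-3) should sit on the steep descending part of the curve. Newer fastai v1 releases can also mark a suggested rate (the suggestion keyword does not exist in older versions):

l.recorder.plot(suggestion=True)  # marks the point of steepest loss descent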
In [14]:
lr = 0.001

In [15]:
l.fit_one_cycle(300, lr)  # note: trains for 300 epochs, not the n_epochs=500 set above


Total time: 1:00:42

epoch train_loss valid_loss
1 0.962619 0.915109
2 0.954268 0.922069
3 0.948310 0.933422
4 0.959162 0.927936
5 0.955292 0.962480
6 0.955686 0.983531
7 0.959359 0.964114
8 0.963162 0.961883
9 0.959086 1.133870
10 0.958194 0.910141
11 0.963431 1.181516
12 0.958487 1.101186
13 0.953688 0.974195
14 0.950342 1.079128
15 0.953169 1.011540
16 0.958791 0.911496
17 0.965622 0.916160
18 0.958196 0.910409
19 0.956805 1.723271
20 0.962710 1.349578
21 0.975255 1.166988
22 0.969709 1.244711
23 0.962910 1.445498
24 0.961744 0.957583
25 0.959938 0.907568
26 0.960245 2.216190
27 0.957248 1.041014
28 0.966550 0.912505
29 0.969333 1.413470
30 0.979346 1.023455
31 0.969207 0.910495
32 0.966716 0.926144
33 0.971191 0.972263
34 0.982054 1.791826
35 0.977938 2.067481
36 0.975192 1.136394
37 0.968994 8.437294
38 0.968171 1.465559
39 0.977725 1.668962
40 0.980343 2.336482
41 0.976578 1.405700
42 0.972189 1.386190
43 0.994605 1.139041
44 0.996136 1.564029
45 1.001471 1.024466
46 0.988763 1.724174
47 0.979455 3.846130
48 0.975557 0.946329
49 0.979129 1.139909
50 0.985777 10.688543
51 0.987837 1.504450
52 0.997345 3.349200
53 0.986095 10.529685
54 0.988243 2.349234
55 0.989963 2.245249
56 0.995621 5.442734
57 1.011026 3.887688
58 1.033823 4.022326
59 1.032340 1.306832
60 1.012564 1.303052
61 1.031588 0.956364
62 1.002397 0.976605
63 1.009010 1.578107
64 1.011429 0.943647
65 0.992546 1.566283
66 0.993822 1.148076
67 1.013706 7.031891
68 0.989051 4.953737
69 1.006781 5.687341
70 1.027610 2.792975
71 1.026081 17.322279
72 1.035339 3.567602
73 1.032049 2.022600
74 1.027075 4.051038
75 1.023858 49.291580
76 1.032621 1.204014
77 1.028244 7.399761
78 1.019167 50.745552
79 1.008186 5.084015
80 1.019345 1.932068
81 1.018064 2.609910
82 1.025124 6.783131
83 1.016703 0.976507
84 1.042682 11.422091
85 1.037994 2.099921
86 1.044919 7.243030
87 1.029688 4.065248
88 1.018333 8.078403
89 1.012945 2.534340
90 0.995931 24.992874
91 1.030319 7.141800
92 1.041096 23.549728
93 1.030407 1.568124
94 1.032346 4.956819
95 1.041503 2.741663
96 1.029881 47.880741
97 1.018023 3.624795
98 1.018865 4.257344
99 1.022362 1.183182
100 1.016897 11.693587
101 1.023401 1.013481
102 1.010706 1.058470
103 1.017431 4.283300
104 1.011519 2.147990
105 1.016998 2.752353
106 1.008513 1.089370
107 1.029516 1.037856
108 1.063280 8.194583
109 1.020817 1.175390
110 1.026166 66.679100
111 1.020615 12.056190
112 1.018855 6.898706
113 1.009621 1.873381
114 1.024827 5.652112
115 1.014743 13.379063
116 1.024454 4.314157
117 1.020448 4.443440
118 1.035727 5.008492
119 1.016518 1.879692
120 1.004506 21.458492
121 1.030957 1.174026
122 1.028610 9.071132
123 1.020107 3.072988
124 1.067836 3.692532
125 1.052181 38.784077
126 1.032354 0.984219
127 1.006042 1.147598
128 1.018080 7.377429
129 1.016079 1.180985
130 1.011210 3.132229
131 1.005414 5.990058
132 1.011475 5.407722
133 1.022535 51.026608
134 1.013938 0.926936
135 1.036134 1.930930
136 1.020005 1.173658
137 1.001774 1.694784
138 1.003198 20.826010
139 0.989681 7.570766
140 0.992044 3.881327
141 1.019274 6.227853
142 1.012186 0.922485
143 1.012391 6.232856
144 1.007821 1.096579
145 1.001016 1.826149
146 1.008592 5.150884
147 0.992744 5.799298
148 1.004445 3.369620
149 1.004207 1.836578
150 0.998920 3.661187
151 0.991942 7.021226
152 0.997703 0.950145
153 0.998807 0.922997
154 0.988512 1.675743
155 0.983039 4.772948
156 0.996659 1.566359
157 0.998596 2.294856
158 1.024160 15.951485
159 1.016937 3.243188
160 0.993930 0.903795
161 0.983244 0.991699
162 0.970437 0.917454
163 0.990595 7.656559
164 0.985155 0.902289
165 0.978467 1.008783
166 0.981866 8.444477
167 1.013059 6.265702
168 0.992509 1.209633
169 0.979364 3.008283
170 0.976698 0.983922
171 0.979335 1.522205
172 0.977496 3.761003
173 0.970319 12.244261
174 0.985421 2.432755
175 0.985892 3.991267
176 0.982820 1.941380
177 0.976472 4.565709
178 0.986833 1.340991
179 0.979794 0.970777
180 0.976931 2.705573
181 0.981839 6.643618
182 0.974869 1.704249
183 0.971486 1.598377
184 0.980343 44.733845
185 0.979107 3.235343
186 0.976853 1.362298
187 0.967989 0.948697
188 0.970369 1.186510
189 0.963412 1.022058
190 0.959873 2.418851
191 0.980325 1.080040
192 0.965825 1.013413
193 0.978723 11.137742
194 0.973777 3.935230
195 0.970769 2.691924
196 0.971074 2.546328
197 0.977482 3.642085
198 0.970317 16.227551
199 0.967983 2.915864
200 0.959402 1.162891
201 0.953550 1.215299
202 0.954696 1.081202
203 0.958483 3.561959
204 0.977614 4.220924
205 0.971425 3.889741
206 0.960017 1.019424
207 0.965003 1.244552
208 0.955041 0.932822
209 0.944537 1.019393
210 0.956985 0.905125
211 0.959666 1.014233
212 0.959652 1.687990
213 0.956034 2.396130
214 0.941874 0.981280
215 0.946802 1.187359
216 0.947629 4.160324
217 0.941372 2.054866
218 0.938585 3.648737
219 0.949912 1.274899
220 0.945518 2.892151
221 0.944726 1.662836
222 0.957335 1.264481
223 0.965665 2.605370
224 0.952171 1.120397
225 0.946867 0.894802
226 0.943715 3.698564
227 0.935892 1.669343
228 0.936448 1.966962
229 0.934748 1.043284
230 0.937146 1.551980
231 0.936010 0.970134
232 0.939646 1.098396
233 0.944255 1.143497
234 0.942528 0.996397
235 0.940038 0.911583
236 0.940685 1.073761
237 0.933694 1.497303
238 0.932743 0.897990
239 0.932732 1.112728
240 0.932047 0.949108
241 0.938022 1.651002
242 0.930220 1.893481
243 0.927720 0.961546
244 0.931749 1.094855
245 0.925514 1.237468
246 0.923974 1.132028
247 0.934754 0.886416
248 0.934846 0.886135
249 0.930747 0.880054
250 0.929142 1.177677
251 0.929388 0.880279
252 0.924301 1.501957
253 0.925970 1.419163
254 0.924293 1.296368
255 0.918227 0.885576
256 0.924943 0.914575
257 0.927457 1.291849
258 0.921260 1.211308
259 0.919223 0.893055
260 0.924816 0.949281
261 0.925698 0.937083
262 0.922957 0.909530
263 0.920958 0.927418
264 0.915863 0.923966
265 0.916177 0.938333
266 0.916985 0.935001
267 0.923018 0.923977
268 0.918734 0.927690
269 0.921372 0.881170
270 0.916836 0.878169
271 0.919865 0.905652
272 0.917705 0.945532
273 0.911727 0.890095
274 0.918045 0.909807
275 0.915745 0.963636
276 0.915717 0.905782
277 0.916208 0.874665
278 0.917313 0.876363
279 0.918392 0.873896
280 0.916550 0.874677
281 0.918100 0.877356
282 0.916956 0.915776
283 0.918556 0.896287
284 0.915844 0.879383
285 0.912546 0.874186
286 0.915408 0.880964
287 0.925465 0.886725
288 0.913104 0.869396
289 0.911308 0.879567
290 0.917529 0.875883
291 0.915668 0.879757
292 0.916806 0.880244
293 0.911883 0.877276
294 0.914535 0.874609
295 0.916008 0.880023
296 0.912126 0.872908
297 0.913582 0.872073
298 0.915000 0.877388
299 0.911661 0.874460
300 0.918898 0.877000
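
The validation loss is highly unstable through the middle of training (spiking past 50 around epochs 75-135) and only settles in the last ~50 epochs. This pattern is consistent with the one-cycle policy, where the learning rate ramps up to its maximum mid-cycle before annealing. The schedule itself can be inspected with the recorder:

l.recorder.plot_lr()  # learning-rate schedule over the whole cycle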


In [18]:
l.recorder.plot_losses()

[plot: training and validation loss curves]
In [17]:
l.save(f"speedup_{optimizer}_batch_norm_{batch_norm}_mse")

In [8]:
l = l.load(f"speedup_{optimizer}_batch_norm_{batch_norm}_mse")

In [9]:
def preds_frame(learner, ds_type):
    # Collect predictions for one split and compare them to the targets.
    preds, targets = learner.get_preds(ds_type)
    preds = preds.reshape((-1,)).numpy()
    targets = targets.reshape((-1,)).numpy()
    df = pd.DataFrame({'pred': preds, 'target': targets})
    df['abs_diff'] = np.abs(df.pred - df.target)
    # Absolute percentage error; assumes the target speedups are strictly positive.
    df['APE'] = df.abs_diff / df.target * 100
    return df

val_df = preds_frame(l, fai.basic_data.DatasetType.Valid)
train_df = preds_frame(l, fai.basic_data.DatasetType.Train)
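
APE is undefined where the target is zero and misleading where it is negative, so a quick sanity check (a sketch; the 1e-6 threshold is arbitrary) confirms no such targets exist before trusting the statistic:

# Count targets for which a percentage error is not meaningful.
print((np.abs(val_df.target) < 1e-6).sum(), 'near-zero targets')
print((val_df.target < 0).sum(), 'negative targets')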

In [19]:
val_df.abs_diff.describe()


Out[19]:
count    10000.000000
mean         0.650753
std          0.695865
min          0.000062
25%          0.201365
50%          0.428166
75%          0.836580
max          5.915672
Name: abs_diff, dtype: float64

In [11]:
val_df.APE.describe()


Out[11]:
count    10000.000000
mean        75.867996
std        140.273727
min          0.008468
25%         14.071760
50%         31.454399
75%         65.824497
max       1650.147095
Name: APE, dtype: float64
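
The mean APE (~76%) is pulled up by a heavy tail (std ~140, max ~1650%), so the median (~31%) is the more representative figure here. A bounded alternative worth considering is symmetric MAPE, sketched below using one common variant of the formula:

# sMAPE is bounded in [0, 200] and less sensitive to near-zero targets.
smape = 200 * val_df.abs_diff / (np.abs(val_df.pred) + np.abs(val_df.target))
smape.describe()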
