In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [40, 30]

In [2]:
import fastai as fai
from fastai.basic_data import DataLoader
from data_loader import *
import numpy as np
from torch.utils.data import SubsetRandomSampler
from model import *
from model_bn import *
from torch import optim
import dill
import papermill as pm
from main import *
import pandas as pd

In [3]:
optimizer = 'Adam'                   # 'Adam' (fastai's default) or 'SGD'
num_workers = 8                      # DataLoader worker processes
maxsize = 100000                     # maximum number of samples to load
batch_size = 2048
n_epochs = 500                       # defined but unused; fit_one_cycle below is called with 300 epochs
batch_norm = True                    # use the batch-normalized model (Model_BN)
dataset = 'data/speedup_dataset.h5'

In [4]:
train_dl, val_dl = train_dev_split(dataset, batch_size, num_workers, maxsize)

db = fai.basic_data.DataBunch(train_dl, val_dl)
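
The split itself is done by train_dev_split from the project's data_loader module, whose implementation is not shown here. Purely as an illustration (the function name, split ratio, and use of SubsetRandomSampler below are assumptions, not the project's actual code), such a split typically looks like this:

# Hypothetical sketch -- not the project's actual train_dev_split.
import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler

def train_dev_split_sketch(ds, batch_size, num_workers, maxsize, val_frac=0.1):
    # Cap the dataset at `maxsize` examples and shuffle the indices.
    n = min(len(ds), maxsize)
    idx = np.random.permutation(n).tolist()
    split = int(n * (1 - val_frac))
    train_dl = DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                          sampler=SubsetRandomSampler(idx[:split]))
    val_dl = DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                        sampler=SubsetRandomSampler(idx[split:]))
    return train_dl, val_dl

A 10% validation split would be consistent with the 10,000-row validation set summarized at the end of this notebook (given maxsize = 100000).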

In [5]:
def criterion(inputs, targets):
    # Mean absolute percentage error (MAPE), in percent;
    # eps avoids division by zero for near-zero targets.
    eps = 1e-5
    return torch.mean(torch.abs(targets - inputs) / (targets + eps) * 100)
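
This loss is the mean absolute percentage error (MAPE), expressed in percent; eps keeps the division stable when a target is close to zero. A quick sanity check on toy tensors (not part of the original run):

# Predictions that are 10% off should give a loss of roughly 10 (percent).
import torch
preds = torch.tensor([1.1, 2.2, 3.3])
targets = torch.tensor([1.0, 2.0, 3.0])
print(criterion(preds, targets))  # ~ tensor(10.)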

In [6]:
input_size = train_dl.dataset.X.shape[1]
output_size = train_dl.dataset.Y.shape[1]

# Pick the architecture: Model_BN adds batch normalization, Model does not.
if batch_norm:
    model = Model_BN(input_size, output_size)
else:
    model = Model(input_size, output_size)

# criterion = nn.MSELoss()  # alternative loss, kept for reference

l = fai.Learner(db, model, loss_func=criterion)

# fastai's Learner already defaults to Adam, so only override for SGD.
if optimizer == 'SGD':
    l.opt_func = optim.SGD
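
Model and Model_BN are defined in model.py and model_bn.py, which are not included in this notebook. Purely as an illustration of what a batch-normalized regressor of this kind might look like (the layer sizes and structure below are assumptions, not the project's actual architecture):

# Illustrative stand-in only -- the real Model_BN lives in model_bn.py and may differ.
import torch.nn as nn

class TinyModelBN(nn.Module):
    def __init__(self, input_size, output_size, hidden=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
            nn.Linear(hidden, output_size),
        )

    def forward(self, x):
        return self.net(x)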

In [11]:
l.lr_find()


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

In [12]:
l.recorder.plot()

[learning-rate finder plot: loss vs. learning rate]
In [13]:
lr = 0.001

In [14]:
l.fit_one_cycle(300, lr)


Total time: 58:57

epoch train_loss valid_loss
1 69.968178 85.605179
2 66.441895 56.247520
3 63.386555 53.371758
4 61.322170 53.186852
5 59.668354 54.909576
6 58.343056 51.773800
7 57.436409 53.187424
8 56.559261 52.113018
9 55.698997 51.065475
10 54.896133 50.632309
11 54.120777 50.396114
12 53.471691 55.479332
13 52.964935 50.489120
14 52.590603 56.709175
15 52.008270 52.234600
16 51.678749 49.410648
17 51.220985 50.528561
18 50.633717 49.216053
19 50.326653 48.581383
20 49.776733 51.501286
21 49.573425 57.175831
22 49.239239 67.229240
23 48.808891 47.446739
24 48.559986 48.970848
25 48.245995 52.241360
26 47.847141 49.472282
27 47.764729 82.360031
28 47.558674 53.046219
29 48.007198 179.849869
30 47.347702 48.772377
31 47.280945 64.661270
32 47.030365 45.761192
33 46.788437 46.185940
34 46.595737 59.126167
35 46.556393 46.900826
36 46.433136 51.536392
37 46.438725 63.887794
38 46.258530 45.753498
39 46.261593 47.337029
40 46.323109 108.877609
41 46.209930 141.157669
42 46.046841 44.761505
43 46.617561 61.543568
44 46.081638 54.706730
45 47.628387 46.929558
46 46.640972 48.314655
47 46.299629 70.453758
48 45.986027 44.670921
49 46.482727 46.546963
50 45.973667 52.853989
51 45.736507 73.744385
52 45.829155 80.690575
53 45.733112 63.749229
54 45.637291 56.247593
55 47.172070 230.157898
56 47.125824 53.439857
57 46.332115 64.605003
58 45.870911 128.964035
59 45.923737 202.876282
60 46.400749 55.983074
61 46.364990 77.465965
62 45.844322 140.040192
63 45.486038 79.547211
64 45.626949 77.189774
65 46.091488 49.767208
66 46.009598 44.021000
67 46.339035 105.313713
68 45.923981 127.137627
69 45.607586 664.445374
70 45.913147 78.290306
71 46.141354 314.453583
72 46.058453 72.510872
73 45.933182 61.804951
74 45.995758 44.509644
75 46.156147 92.153503
76 45.804726 76.441475
77 45.420963 62.312950
78 45.656990 54.659473
79 45.569992 69.906754
80 46.072536 83.193405
81 45.710159 48.805492
82 46.182484 82.847313
83 46.057045 62.508251
84 46.560955 61.487556
85 46.295887 166.097107
86 45.760628 76.496635
87 45.422207 44.069458
88 45.881763 74.982391
89 46.161362 123.555588
90 45.561684 75.112152
91 45.947411 317.677765
92 45.616772 72.639008
93 45.264297 484.107391
94 45.921516 44.731308
95 45.506981 183.908432
96 45.347134 56.058983
97 45.014538 68.197433
98 45.015842 45.678055
99 45.561119 65.688583
100 45.731907 81.568108
101 46.166634 50.077755
102 45.542789 47.815662
103 45.837372 362.516449
104 45.623901 76.072815
105 45.883686 706.319031
106 45.076077 45.572281
107 45.230499 76.141685
108 45.039795 47.177864
109 45.515213 56.159325
110 45.371536 44.558876
111 45.119331 281.675140
112 44.858742 43.047749
113 45.237099 521.570801
114 45.638920 45.468723
115 45.211163 691.752441
116 45.887669 70.913971
117 45.097851 43.584148
118 44.898434 59.860626
119 44.858173 81.004753
120 45.417027 63.619705
121 45.002289 173.215805
122 45.005154 165.290726
123 44.938831 83.054672
124 45.794746 83.589378
125 45.514317 78.801308
126 45.664505 131.904755
127 45.551285 128.989868
128 45.925030 78.523331
129 45.718021 52.055511
130 45.122078 268.819763
131 45.372917 45.346470
132 45.029446 56.117401
133 46.564419 60.643364
134 46.231457 247.339767
135 45.387939 44.245468
136 44.843567 43.226936
137 44.941032 462.750458
138 45.521030 83.131981
139 45.221558 344.910736
140 45.185585 298.745300
141 45.185505 423.242462
142 44.626343 43.734356
143 44.578930 70.428459
144 44.949478 210.566025
145 44.737511 53.618225
146 44.551914 70.233727
147 45.051224 64.489563
148 44.407108 57.380836
149 44.633167 79.551720
150 45.277416 51.411037
151 44.979164 44.168293
152 45.037834 51.855350
153 44.778156 64.013885
154 44.579681 269.929718
155 44.245132 61.638855
156 44.403698 296.052307
157 44.922981 44.376209
158 45.046207 132.338806
159 44.506092 54.022224
160 44.149605 59.519024
161 44.540977 82.810959
162 44.293770 615.111511
163 44.385010 80.757660
164 44.321644 341.056671
165 44.335121 69.914803
166 45.454033 70.508797
167 44.678867 118.014961
168 44.345470 54.882648
169 44.315010 71.954399
170 44.069134 64.280273
171 44.320099 77.238304
172 44.234745 241.955154
173 44.602329 54.271324
174 44.434906 47.870224
175 43.997417 67.296761
176 43.752609 376.097687
177 43.751518 80.968147
178 44.167309 54.609852
179 44.133091 42.689808
180 43.957283 59.117462
181 43.889183 45.617088
182 44.884327 44.282856
183 44.234188 68.765038
184 44.165756 45.361481
185 44.030537 44.614002
186 44.031403 62.738861
187 43.829334 53.558067
188 43.922924 62.118969
189 43.857056 98.703583
190 43.991394 42.585567
191 43.812874 97.339615
192 43.645409 66.502167
193 43.826717 248.689301
194 44.005714 78.019653
195 43.811211 53.676456
196 43.814163 55.039139
197 43.870312 66.544151
198 43.960567 144.137146
199 43.591049 61.640182
200 43.542015 73.223457
201 43.722115 87.237549
202 43.492687 98.536163
203 43.221046 58.342850
204 43.345512 43.850941
205 43.814209 116.554802
206 43.845581 45.078548
207 43.537483 44.132072
208 44.245789 42.652512
209 44.479549 77.638184
210 44.201580 43.516598
211 43.827499 59.204342
212 43.713558 44.036041
213 43.647179 49.014725
214 43.805790 138.839523
215 43.852734 61.745850
216 43.538132 48.903893
217 43.535225 48.714443
218 43.426926 43.083370
219 43.547295 89.383293
220 43.304459 68.985947
221 43.232632 44.613724
222 43.151043 42.080040
223 43.110481 42.199398
224 43.446877 65.677696
225 43.336403 87.559372
226 43.311462 45.094360
227 42.960167 100.345703
228 42.975746 42.703308
229 42.996170 45.398914
230 43.055256 42.504745
231 43.086430 43.644585
232 43.036507 47.098946
233 43.189690 51.669693
234 43.240742 44.833618
235 43.402519 42.711987
236 43.025017 73.776253
237 43.049332 50.435688
238 43.019424 49.235184
239 42.933365 41.959305
240 42.952263 51.807728
241 43.050110 51.029846
242 42.854607 43.730614
243 42.958244 42.681091
244 42.861233 61.606056
245 42.896477 47.389164
246 43.291500 45.773640
247 43.251610 48.318970
248 43.165619 42.980186
249 43.408447 55.606617
250 43.070183 63.832920
251 42.935570 44.002399
252 42.912983 42.168449
253 42.831703 42.183586
254 42.791317 41.914383
255 42.895344 42.690014
256 42.860626 52.275894
257 42.868362 41.927471
258 42.912769 71.980934
259 42.769089 43.168133
260 42.809105 45.619564
261 42.876507 41.959141
262 42.934387 45.517128
263 42.760799 42.484528
264 42.730930 42.517395
265 42.728622 41.865868
266 42.779972 43.741795
267 42.740959 42.473969
268 42.686111 41.801018
269 42.690235 41.761070
270 42.673679 42.316708
271 42.720142 42.669167
272 42.729576 48.029236
273 42.730125 43.212505
274 42.721119 41.975224
275 42.683792 42.318611
276 42.621040 42.501965
277 42.586929 41.919006
278 42.628792 41.805351
279 42.708797 42.485100
280 42.710697 41.783192
281 42.682159 41.873112
282 42.669506 42.468811
283 42.783146 41.897892
284 42.873787 42.150120
285 42.718662 41.826279
286 42.656605 41.861645
287 42.654648 41.876064
288 42.642971 41.942310
289 42.664837 42.252117
290 42.830624 41.791561
291 42.674290 41.816559
292 42.698174 41.892845
293 42.637264 41.752548
294 42.603455 41.759281
295 42.691261 41.756012
296 42.663673 41.885921
297 42.613098 41.815525
298 42.579929 41.866638
299 42.608913 41.780209
300 42.676521 41.894428


In [17]:
l.recorder.plot_losses()

[plot of training and validation loss over the run]
In [16]:
l.save(f"speedup_{optimizer}_batch_norm_{batch_norm}_mape")

In [7]:
l = l.load(f"speedup_{optimizer}_batch_norm_{batch_norm}_mape")

In [8]:
val_df = pd.DataFrame()
train_df = pd.DataFrame()

# Predictions and targets on the validation set.
preds, targets = l.get_preds(fai.basic_data.DatasetType.Valid)

preds = preds.reshape((-1,)).numpy()
targets = targets.reshape((-1,)).numpy()

val_df['pred'] = preds
val_df['target'] = targets
val_df['abs_diff'] = np.abs(preds - targets)
# Per-sample absolute percentage error, in percent.
val_df['APE'] = np.abs(val_df.target - val_df.pred) / val_df.target * 100

# Same statistics on the training set.
preds, targets = l.get_preds(fai.basic_data.DatasetType.Train)

preds = preds.reshape((-1,)).numpy()
targets = targets.reshape((-1,)).numpy()

train_df['pred'] = preds
train_df['target'] = targets
train_df['abs_diff'] = np.abs(preds - targets)
train_df['APE'] = np.abs(train_df.target - train_df.pred) / train_df.target * 100
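
Since the training criterion is itself a mean APE, the mean of the APE column computed here should be comparable to the valid_loss values logged above (they need not match exactly: the logged loss averages per-batch means and includes the eps term). A quick cross-check, not part of the original run:

# Compare the dataframe-level mean APE with the logged validation loss.
print(f"val MAPE:   {val_df.APE.mean():.2f}")
print(f"train MAPE: {train_df.APE.mean():.2f}")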

In [9]:
val_df.abs_diff.describe()


Out[9]:
count    10000.000000
mean         0.835158
std          1.139942
min          0.000013
25%          0.123175
50%          0.359445
75%          1.040493
max         10.972778
Name: abs_diff, dtype: float64

In [10]:
val_df.APE.describe()


Out[10]:
count    10000.000000
mean        39.477051
std         32.403385
min          0.001978
25%         15.713947
50%         34.318180
75%         56.991356
max        437.290344
Name: APE, dtype: float64
