Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository because it will load fastai-0.7.x


In [ ]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [ ]:
?re.compile

Super resolution data


In [ ]:
from fastai.conv_learner import *
from pathlib import Path
torch.cuda.set_device(0)

torch.backends.cudnn.benchmark=True

In [ ]:
PATH = Path('data/imagenet')
PATH_TRN = PATH/'train'

In [ ]:
fnames_full,label_arr_full,all_labels = folder_source(PATH, 'train')
fnames_full = ['/'.join(Path(fn).parts[-2:]) for fn in fnames_full]
list(zip(fnames_full[:5],label_arr_full[:5]))


Out[ ]:
[('n01440764/n01440764_9627.JPEG', 0),
 ('n01440764/n01440764_9609.JPEG', 0),
 ('n01440764/n01440764_5176.JPEG', 0),
 ('n01440764/n01440764_6936.JPEG', 0),
 ('n01440764/n01440764_4005.JPEG', 0)]

In [ ]:
all_labels[:5]


Out[ ]:
['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475']

In [ ]:
np.random.seed(42)
# keep_pct = 1.
keep_pct = 0.02
keeps = np.random.rand(len(fnames_full)) < keep_pct
fnames = np.array(fnames_full, copy=False)[keeps]
label_arr = np.array(label_arr_full, copy=False)[keeps]

In [ ]:
arch = vgg16
sz_lr = 72

In [ ]:
# scale,bs = 2,64
scale,bs = 4,24
sz_hr = sz_lr*scale

In [ ]:
class MatchedFilesDataset(FilesDataset):
    def __init__(self, fnames, y, transform, path):
        self.y=y
        assert(len(fnames)==len(y))
        super().__init__(fnames, transform, path)
    def get_y(self, i): return open_image(os.path.join(self.path, self.y[i]))
    def get_c(self): return 0

In [ ]:
aug_tfms = [RandomDihedral(tfm_y=TfmType.PIXEL)]

In [ ]:
val_idxs = get_cv_idxs(len(fnames), val_pct=min(0.01/keep_pct, 0.1))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, np.array(fnames), np.array(fnames))
len(val_x),len(trn_x)


Out[ ]:
(2558, 23026)

In [ ]:
img_fn = PATH/'train'/'n01558993'/'n01558993_9684.JPEG'

In [ ]:
tfms = tfms_from_model(arch, sz_lr, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms, sz_y=sz_hr)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH_TRN)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)

In [ ]:
denorm = md.val_ds.denorm

In [ ]:
def show_img(ims, idx, figsize=(5,5), normed=True, ax=None):
    if ax is None: fig,ax = plt.subplots(figsize=figsize)
    if normed: ims = denorm(ims)
    else:      ims = np.rollaxis(to_np(ims),1,4)
    ax.imshow(np.clip(ims,0,1)[idx])
    ax.axis('off')

In [ ]:
x,y = next(iter(md.val_dl))
x.size(),y.size()


Out[ ]:
(torch.Size([64, 3, 72, 72]), torch.Size([64, 3, 144, 144]))

In [ ]:
idx=1
fig,axes = plt.subplots(1, 2, figsize=(9,5))
show_img(x,idx, ax=axes[0])
show_img(y,idx, ax=axes[1])



In [ ]:
batches = [next(iter(md.aug_dl)) for i in range(9)]

In [ ]:
fig, axes = plt.subplots(3, 6, figsize=(18, 9))
for i,(x,y) in enumerate(batches):
    show_img(x,idx, ax=axes.flat[i*2])
    show_img(y,idx, ax=axes.flat[i*2+1])


Model


In [ ]:
def conv(ni, nf, kernel_size=3, actn=True):
    layers = [nn.Conv2d(ni, nf, kernel_size, padding=kernel_size//2)]
    if actn: layers.append(nn.ReLU(True))
    return nn.Sequential(*layers)

In [ ]:
class ResSequential(nn.Module):
    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        x = x + self.m(x) * self.res_scale
        return x

In [ ]:
def res_block(nf):
    return ResSequential(
        [conv(nf, nf), conv(nf, nf, actn=False)],
        0.1)

In [ ]:
def upsample(ni, nf, scale):
    layers = []
    for i in range(int(math.log(scale,2))):
        layers += [conv(ni, nf*4), nn.PixelShuffle(2)]
    return nn.Sequential(*layers)

In [ ]:
class SrResnet(nn.Module):
    def __init__(self, nf, scale):
        super().__init__()
        features = [conv(3, 64)]
        for i in range(8): features.append(res_block(64))
        features += [conv(64,64), upsample(64, 64, scale),
                     nn.BatchNorm2d(64),
                     conv(64, 3, actn=False)]
        self.features = nn.Sequential(*features)
        
    def forward(self, x): return self.features(x)

Pixel loss


In [ ]:
m = to_gpu(SrResnet(64, scale))
m = nn.DataParallel(m, [0,2])
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)
learn.crit = F.mse_loss

In [ ]:
learn.lr_find(start_lr=1e-5, end_lr=10000)
learn.sched.plot()


 31%|███▏      | 225/720 [00:24<00:53,  9.19it/s, loss=0.0482]

In [ ]:
lr=2e-3

In [ ]:
learn.fit(lr, 1, cycle_len=1, use_clr_beta=(40,10))


  2%|▏         | 15/720 [00:02<01:52,  6.25it/s, loss=0.042]  
epoch      trn_loss   val_loss                                 
    0      0.007431   0.008192  

Out[ ]:
[array([0.00819])]

In [ ]:
x,y = next(iter(md.val_dl))
preds = learn.model(VV(x))

In [ ]:
idx=4
show_img(y,idx,normed=False)



In [ ]:
show_img(preds,idx,normed=False);



In [ ]:
show_img(x,idx,normed=True);



In [ ]:
x,y = next(iter(md.val_dl))
preds = learn.model(VV(x))

In [ ]:
show_img(y,idx,normed=False)



In [ ]:
show_img(preds,idx,normed=False);



In [ ]:
show_img(x,idx);


Perceptual loss


In [ ]:
def icnr(x, scale=2, init=nn.init.kaiming_normal):
    new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:])
    subkernel = torch.zeros(new_shape)
    subkernel = init(subkernel)
    subkernel = subkernel.transpose(0, 1)
    subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                            subkernel.shape[1], -1)
    kernel = subkernel.repeat(1, 1, scale ** 2)
    transposed_shape = [x.shape[1]] + [x.shape[0]] + list(x.shape[2:])
    kernel = kernel.contiguous().view(transposed_shape)
    kernel = kernel.transpose(0, 1)
    return kernel

In [ ]:
m_vgg = vgg16(True)

blocks = [i-1 for i,o in enumerate(children(m_vgg))
              if isinstance(o,nn.MaxPool2d)]
blocks, [m_vgg[i] for i in blocks]


Out[ ]:
([5, 12, 22, 32, 42],
 [ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace)])

In [ ]:
vgg_layers = children(m_vgg)[:13]
m_vgg = nn.Sequential(*vgg_layers).cuda().eval()
set_trainable(m_vgg, False)

In [ ]:
def flatten(x): return x.view(x.size(0), -1)

In [ ]:
class SaveFeatures():
    features=None
    def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): self.features = output
    def remove(self): self.hook.remove()

In [ ]:
class FeatureLoss(nn.Module):
    def __init__(self, m, layer_ids, layer_wgts):
        super().__init__()
        self.m,self.wgts = m,layer_wgts
        self.sfs = [SaveFeatures(m[i]) for i in layer_ids]

    def forward(self, input, target, sum_layers=True):
        self.m(VV(target.data))
        res = [F.l1_loss(input,target)/100]
        targ_feat = [V(o.features.data.clone()) for o in self.sfs]
        self.m(input)
        res += [F.l1_loss(flatten(inp.features),flatten(targ))*wgt
               for inp,targ,wgt in zip(self.sfs, targ_feat, self.wgts)]
        if sum_layers: res = sum(res)
        return res
    
    def close(self):
        for o in self.sfs: o.remove()

In [ ]:
m = SrResnet(64, scale)

In [ ]:
conv_shuffle = m.features[10][0][0]
kernel = icnr(conv_shuffle.weight, scale=scale)
conv_shuffle.weight.data.copy_(kernel);


/home/jhoward/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:4: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
  after removing the cwd from sys.path.

In [ ]:
conv_shuffle = m.features[10][2][0]
kernel = icnr(conv_shuffle.weight, scale=scale)
conv_shuffle.weight.data.copy_(kernel);


/home/jhoward/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:4: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
  after removing the cwd from sys.path.

In [ ]:
m = to_gpu(m)

In [ ]:
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)

In [ ]:
t = torch.load(learn.get_model_path('sr-samp0'), map_location=lambda storage, loc: storage)
learn.model.load_state_dict(t, strict=False)

In [ ]:
learn.freeze_to(999)

In [ ]:
for i in range(10,13): set_trainable(m.features[i], True)

In [ ]:
lr=6e-3
wd=1e-7

In [ ]:
learn.fit(lr, 1, cycle_len=1, wds=wd, use_clr=(20,10))


epoch      trn_loss   val_loss                                
    0      0.097629   0.091069  

Out[ ]:
[0.09106878512207654]

In [ ]:
learn.crit = FeatureLoss(m_vgg, blocks[:2], [0.26,0.74])

In [ ]:
learn.lr_find(1e-4, 1., wds=wd, linear=True)


 19%|█▊        | 178/960 [01:28<06:28,  2.01it/s, loss=0.141] 
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-36-5feb8d75bbd0> in <module>()
----> 1 learn.lr_find(1e-4, 1., wds=wd, linear=True)

/data1/jhoward/git/fastai/courses/dl2/fastai/learner.py in lr_find(self, start_lr, end_lr, wds, linear, **kwargs)
    328         layer_opt = self.get_layer_opt(start_lr, wds)
    329         self.sched = LR_Finder(layer_opt, len(self.data.trn_dl), end_lr, linear=linear)
--> 330         self.fit_gen(self.model, self.data, layer_opt, 1, **kwargs)
    331         self.load('tmp')
    332 

/data1/jhoward/git/fastai/courses/dl2/fastai/learner.py in fit_gen(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, best_save_name, use_clr, use_clr_beta, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, use_swa, swa_start, swa_eval_freq, **kwargs)
    232             metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, fp16=self.fp16,
    233             swa_model=self.swa_model if use_swa else None, swa_start=swa_start,
--> 234             swa_eval_freq=swa_eval_freq, **kwargs)
    235 
    236     def get_layer_groups(self): return self.models.get_layer_groups()

/data1/jhoward/git/fastai/courses/dl2/fastai/model.py in fit(model, data, n_epochs, opt, crit, metrics, callbacks, stepper, swa_model, swa_start, swa_eval_freq, **kwargs)
    127             batch_num += 1
    128             for cb in callbacks: cb.on_batch_begin()
--> 129             loss = model_stepper.step(V(x),V(y), epoch)
    130             avg_loss = avg_loss * avg_mom + loss * (1-avg_mom)
    131             debias_loss = avg_loss / (1 - avg_mom**batch_num)

/data1/jhoward/git/fastai/courses/dl2/fastai/model.py in step(self, xs, y, epoch)
     50         if self.fp16: self.m.zero_grad()
     51         else: self.opt.zero_grad()
---> 52         loss = raw_loss = self.crit(output, y)
     53         if self.loss_scale != 1: assert(self.fp16); loss = loss*self.loss_scale
     54         if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    489             result = self._slow_forward(*input, **kwargs)
    490         else:
--> 491             result = self.forward(*input, **kwargs)
    492         for hook in self._forward_hooks.values():
    493             hook_result = hook(self, input, result)

<ipython-input-26-e111310935a1> in forward(self, input, target, sum_layers)
     11         self.m(input)
     12         res += [F.l1_loss(flatten(inp.features),flatten(targ))*wgt
---> 13                for inp,targ,wgt in zip(self.sfs, targ_feat, self.wgts)]
     14         if sum_layers: res = sum(res)
     15         return res

<ipython-input-26-e111310935a1> in <listcomp>(.0)
     11         self.m(input)
     12         res += [F.l1_loss(flatten(inp.features),flatten(targ))*wgt
---> 13                for inp,targ,wgt in zip(self.sfs, targ_feat, self.wgts)]
     14         if sum_layers: res = sum(res)
     15         return res

~/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py in l1_loss(input, target, size_average, reduce)
   1556     """
   1557     return _pointwise_loss(lambda a, b: torch.abs(a - b), torch._C._nn.l1_loss,
-> 1558                            input, target, size_average, reduce)
   1559 
   1560 

~/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py in _pointwise_loss(lambd, lambd_optimized, input, target, size_average, reduce)
   1535         return torch.mean(d) if size_average else torch.sum(d)
   1536     else:
-> 1537         return lambd_optimized(input, target, size_average, reduce)
   1538 
   1539 

KeyboardInterrupt: 

In [ ]:
learn.load('tmp')

In [ ]:
learn.sched.plot(0, n_skip_end=1)



In [ ]:
learn.save('sr-samp0')

In [ ]:
learn.unfreeze()

In [ ]:
learn.fit(lr, 1, cycle_len=1, wds=wd, use_clr=(20,10))


epoch      trn_loss   val_loss                                
    0      0.063747   0.060955  

Out[ ]:
[0.060954763361072986]

In [ ]:
learn.fit(lr, 1, cycle_len=2, wds=wd, use_clr=(20,10))


epoch      trn_loss   val_loss                                
    0      0.059319   0.058087  
    1      0.057485   0.056862                                

Out[ ]:
[0.05686224508299503]

In [ ]:
learn.fit(lr, 1, cycle_len=2, wds=wd, use_clr=(20,10))


epoch      trn_loss   val_loss                                
    0      0.066028   0.064855  
    1      0.063048   0.062271                                

Out[ ]:
[0.06227088433583329]

In [ ]:
learn.sched.plot_loss()



In [ ]:
learn.load('sr-samp1')

In [ ]:
learn.save('sr-samp1')

In [ ]:
learn.load('sr-samp1')

In [ ]:
lr=3e-3

In [ ]:
learn.fit(lr, 1, cycle_len=1, wds=wd, use_clr=(20,10))


epoch      trn_loss   val_loss                                
    0      0.069054   0.06638   

Out[ ]:
[array([0.06638])]

In [ ]:
learn.save('sr-samp2')

In [ ]:
learn.unfreeze()

In [ ]:
learn.load('sr-samp2')

In [ ]:
learn.fit(lr/3, 1, cycle_len=1, wds=wd, use_clr=(20,10))


epoch      trn_loss   val_loss                                        
    0      0.06042    0.057613  

Out[ ]:
[array([0.05761])]

In [ ]:
learn.save('sr1')

In [ ]:
def plot_ds_img(idx, ax=None, figsize=(7,7), normed=True):
    if ax is None: fig,ax = plt.subplots(figsize=figsize)
    im = md.val_ds[idx][0]
    if normed: im = denorm(im)[0]
    else:      im = np.rollaxis(to_np(im),0,3)
    ax.imshow(im)
    ax.axis('off')

In [ ]:
fig,axes=plt.subplots(6,6,figsize=(20,20))
for i,ax in enumerate(axes.flat): plot_ds_img(i+200,ax=ax, normed=True)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

In [ ]:
x,y=md.val_ds[201]

In [ ]:
y=y[None]

In [ ]:
learn.model.eval()
preds = learn.model(VV(x[None]))
x.shape,y.shape,preds.shape


Out[ ]:
((3, 72, 72), (1, 3, 288, 288), torch.Size([1, 3, 288, 288]))

In [ ]:
learn.crit(preds, V(y), sum_layers=False)


Out[ ]:
[tensor(1.00000e-03 *
        2.0065, device='cuda:0'), tensor(1.00000e-02 *
        2.2847, device='cuda:0'), tensor(1.00000e-02 *
        4.8297, device='cuda:0')]

In [ ]:
learn.crit(preds, V(y), sum_layers=False)


Out[ ]:
[tensor(1.00000e-03 *
        2.1613, device='cuda:0'), tensor(1.00000e-03 *
        4.7700, device='cuda:0'), tensor(1.00000e-02 *
        4.4875, device='cuda:0'), tensor(1.00000e-02 *
        1.0256, device='cuda:0')]

In [ ]:
learn.crit.close()

In [ ]:
_,axes=plt.subplots(1,2,figsize=(14,7))
show_img(x[None], 0, ax=axes[0])
show_img(preds,0, normed=True, ax=axes[1])



In [ ]:
_,axes=plt.subplots(1,2,figsize=(14,7))
show_img(x[None], 0, ax=axes[0])
show_img(preds,0, normed=True, ax=axes[1])



In [ ]:
_,axes=plt.subplots(1,2,figsize=(14,7))
show_img(x[None], 0, ax=axes[0])
show_img(preds,0, normed=True, ax=axes[1])



In [ ]: