Practical Deep Learning for Coders, v3

Lesson7_superres_gan

Pretrained GAN


In [ ]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *

In [ ]:
path = untar_data(URLs.PETS)
path_hr = path/'images'
path_lr = path/'crappy'

Crappified data

Prepare the input data by crappifying the images, that is, deliberately degrading the high-resolution originals so they can serve as low-quality inputs.


In [ ]:
from crappify import *

Uncomment the two lines below the first time you run this notebook, so that the crappified images get generated.


In [ ]:
#il = ImageList.from_folder(path_hr)
#parallel(crappifier(path_lr, path_hr), il.items)

For gradual resizing, we can switch to one of the commented-out lines here.


In [ ]:
bs,size = 32,128
#bs,size = 24,160
#bs,size = 8,256
arch = models.resnet34
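
Gradual resizing can be pictured as a small schedule of (batch size, image size) stages, trained from the smallest images to the largest. The sketch below is only an illustration: `run_progressive` and `fake_train_stage` are hypothetical helpers, not fastai functions, and the three stages simply reuse the pairs from the cell above. The batch size shrinks as the image size grows, presumably to keep the number of pixels per batch roughly level.

```python
# Hypothetical schedule: each stage pairs a batch size with an image size.
stages = [(32, 128), (24, 160), (8, 256)]

def run_progressive(stages, train_stage):
    """Call train_stage(bs, size) for each stage, smallest images first."""
    history = []
    for bs, size in sorted(stages, key=lambda s: s[1]):
        history.append(train_stage(bs, size))
    return history

def fake_train_stage(bs, size):
    # Stand-in for real training: only reports how many pixels one batch holds.
    return {"bs": bs, "size": size, "pixels_per_batch": bs * size * size}

history = run_progressive(stages, fake_train_stage)
for h in history:
    print(h)
```

In the notebook itself, a stage would amount to rebuilding the data at the new `bs` and `size` and continuing to train the same learner.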

Pre-train generator

Now let's pretrain the generator on its own, with a plain pixel loss, before any adversarial training.


In [ ]:
arch = models.resnet34
src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)

In [ ]:
def get_data(bs,size):
    data = (src.label_from_func(lambda x: path_hr/x.name)
           .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
           .databunch(bs=bs).normalize(imagenet_stats, do_y=True))

    data.c = 3
    return data

In [ ]:
data_gen = get_data(bs,size)

In [ ]:
data_gen.show_batch(4)



In [ ]:
wd = 1e-3

In [ ]:
y_range = (-3.,3.)
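
A quick check of why (-3, 3) is a reasonable output range: the targets are normalized with the ImageNet statistics, so a pixel in [0, 1] ends up somewhere between roughly -2.1 and 2.6. The sketch below computes those bounds and shows the usual sigmoid rescaling into a range. `sigmoid_range` here is a plain-Python stand-in written for illustration, under the assumption that the model's final layer squashes its output this way when a `y_range` is given.

```python
import math

# ImageNet per-channel mean and std (the values normally used as imagenet_stats).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# A pixel p in [0, 1] becomes (p - mean) / std after normalization.
lows = [(0.0 - m) / s for m, s in zip(mean, std)]
highs = [(1.0 - m) / s for m, s in zip(mean, std)]
print(min(lows), max(highs))

def sigmoid_range(x, low, high):
    """Squash a real number x into the open interval (low, high)."""
    return (high - low) / (1.0 + math.exp(-x)) + low

y_range = (-3.0, 3.0)
print(sigmoid_range(0.0, *y_range))
```

Every normalized pixel value falls strictly inside (-3, 3), so the range leaves a little headroom without letting the output drift far from valid values.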

In [ ]:
loss_gen = MSELossFlat()

In [ ]:
def create_gen_learner():
    return unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Weight,
                         self_attention=True, y_range=y_range, loss_func=loss_gen)

In [ ]:
learn_gen = create_gen_learner()

In [ ]:
learn_gen.fit_one_cycle(2, pct_start=0.8)


Total time: 01:35

epoch train_loss valid_loss
1 0.061653 0.053493
2 0.051248 0.047272


In [ ]:
learn_gen.unfreeze()

In [ ]:
learn_gen.fit_one_cycle(3, slice(1e-6,1e-3))


Total time: 02:24

epoch train_loss valid_loss
1 0.050429 0.046088
2 0.049056 0.043954
3 0.045437 0.043146


In [ ]:
learn_gen.show_results(rows=4)



In [ ]:
learn_gen.save('gen-pre2')

Save generated images


In [ ]:
learn_gen.load('gen-pre2');

In [ ]:
name_gen = 'image_gen'
path_gen = path/name_gen

In [ ]:
# shutil.rmtree(path_gen)

In [ ]:
path_gen.mkdir(exist_ok=True)

In [ ]:
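def save_preds(dl):
    # Run the generator over every batch and save each prediction under its source file name.
    # names[i] only lines up with the predictions because dl is not shuffled (we pass fix_dl).
    i=0
    names = dl.dataset.items
    
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen/names[i].name)
            i += 1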

In [ ]:
save_preds(data_gen.fix_dl)

In [ ]:
PIL.Image.open(path_gen.ls()[0])


Out[ ]:

Train critic


In [ ]:
learn_gen=None
gc.collect()


Out[ ]:
3755

Pretrain the critic as a binary classifier that tells crappy images from non-crappy ones.


In [ ]:
def get_crit_data(classes, bs, size):
    src = ImageList.from_folder(path, include=classes).split_by_rand_pct(0.1, seed=42)
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(get_transforms(max_zoom=2.), size=size)
           .databunch(bs=bs).normalize(imagenet_stats))
    data.c = 3
    return data
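
The critic's two classes come from the folder each image sits in, in the order given by `classes`. A minimal sketch of that labelling rule, using only `pathlib`: the helper `label_from_parent` and the file names are hypothetical and stand in for what labelling from folders does.

```python
from pathlib import PurePosixPath

# Same order as the classes list passed for the pretraining data:
# generated images first, real images second.
classes = ["image_gen", "images"]

def label_from_parent(p, classes):
    """Return the class index given by the file's parent folder name."""
    return classes.index(PurePosixPath(p).parent.name)

# Hypothetical file names, one generated and one real version of the same picture.
files = ["pets/image_gen/cat_1.jpg", "pets/images/cat_1.jpg"]
labels = [label_from_parent(f, classes) for f in files]
print(labels)
```

With this ordering the real images get index 1 and the generated ones index 0, which matches the targets used for the critic loss later in the notebook.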

In [ ]:
data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)

In [ ]:
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)



In [ ]:
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

In [ ]:
def create_critic_learner(data, metrics):
    return Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd)

In [ ]:
learn_critic = create_critic_learner(data_crit, accuracy_thresh_expand)

In [ ]:
learn_critic.fit_one_cycle(6, 1e-3)


Total time: 09:40

epoch train_loss valid_loss accuracy_thresh_expand
1 0.678256 0.687312 0.531083
2 0.434768 0.366180 0.851823
3 0.186435 0.128874 0.955214
4 0.120681 0.072901 0.980228
5 0.099568 0.107304 0.962564
6 0.071958 0.078094 0.976239


In [ ]:
learn_critic.save('critic-pre2')

GAN

Now we'll combine those pretrained models in a GAN.


In [ ]:
learn_critic=None
learn_gen=None
gc.collect()


Out[ ]:
15794

In [ ]:
data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)

In [ ]:
learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre2')

In [ ]:
learn_gen = create_gen_learner().load('gen-pre2')

To define a GAN Learner, we just have to specify the learner objects for the generator and the critic. The switcher is a callback that decides when to switch from discriminator to generator and vice versa. Here we do as many iterations of the discriminator as needed to get its loss back below a threshold (critic_thresh, set to 0.65 in the cell below), then one iteration of the generator.

The loss of the critic is given by learn_crit.loss_func. We take the average of this loss function on the batch of real predictions (target 1) and the batch of fake predictions (target 0).

The loss of the generator is a weighted sum (weights in weights_gen) of two terms: learn_crit.loss_func applied to the batch of fakes (passed through the critic to become predictions) with a target of 1, and learn_gen.loss_func applied to the output (the batch of fakes) and the target (the corresponding batch of superres images).
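
The two losses described above can be sketched in plain Python on lists of numbers. This is an illustration under simplifying assumptions, not the library's implementation: the critic outputs are treated as one logit per image, the critic loss is binary cross-entropy on logits, the generator's own loss is MSE, and the function names are made up for the example.

```python
import math

def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit, written in a numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def mean(xs):
    return sum(xs) / len(xs)

def critic_loss(real_logits, fake_logits):
    # Real images should be scored 1, generated images 0; average the two halves.
    real = mean([bce_with_logits(z, 1.0) for z in real_logits])
    fake = mean([bce_with_logits(z, 0.0) for z in fake_logits])
    return (real + fake) / 2

def generator_loss(fake_logits, fake_pixels, target_pixels, weights=(1.0, 50.0)):
    # Adversarial term: the generator wants the critic to answer 1 on its fakes.
    adv = mean([bce_with_logits(z, 1.0) for z in fake_logits])
    # Pixel term: plain MSE between generated and target values.
    mse = mean([(f - t) ** 2 for f, t in zip(fake_pixels, target_pixels)])
    return weights[0] * adv + weights[1] * mse

print(critic_loss([2.0, 1.0], [-1.5, 0.3]))
print(generator_loss([-1.5, 0.3], [0.2, 0.4], [0.25, 0.5]))
```

With weights of (1, 50), the pixel term dominates unless the generated images are already close to the targets, which keeps the generator anchored to the actual image content while the adversarial term nudges it toward sharper output.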


In [ ]:
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
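
To make the switching rule concrete, here is a toy simulation of the schedule it produces. It assumes the simplified behaviour described in the text: keep training the critic until its loss falls below `critic_thresh`, then give the generator exactly one step. `simulate_switcher` is a hypothetical helper and the loss values are made up for the example.

```python
def simulate_switcher(critic_losses, critic_thresh=0.65):
    """Return the sequence of training steps implied by a list of critic losses."""
    schedule = []
    for loss in critic_losses:
        schedule.append("critic")
        if loss < critic_thresh:
            # The critic is good enough again: one generator step, then back to the critic.
            schedule.append("generator")
    return schedule

steps = simulate_switcher([0.9, 0.7, 0.6, 0.8, 0.5])
print(steps)
```

The effect is that the critic gets as many steps as it needs to stay ahead, so the number of critic steps per generator step adapts to how training is going rather than being fixed in advance.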

In [ ]:
lr = 1e-4

In [ ]:
learn.fit(40,lr)


Total time: 1:05:41

epoch train_loss gen_loss disc_loss
1 2.071352 2.025429 4.047686
2 1.996251 1.850199 3.652173
3 2.001999 2.035176 3.612669
4 1.921844 1.931835 3.600355
5 1.987216 1.961323 3.606629
6 2.022372 2.102732 3.609494
7 1.900056 2.059208 3.581742
8 1.942305 1.965547 3.538015
9 1.954079 2.006257 3.593008
10 1.984677 1.771790 3.617556
11 2.040979 2.079904 3.575464
12 2.009052 1.739175 3.626755
13 2.014115 1.204614 3.582353
14 2.042148 1.747239 3.608723
15 2.113957 1.831483 3.684338
16 1.979398 1.923163 3.600483
17 1.996756 1.760739 3.635300
18 1.976695 1.982629 3.575843
19 2.088960 1.822936 3.617471
20 1.949941 1.996513 3.594223
21 2.079416 1.918284 3.588732
22 2.055047 1.869254 3.602390
23 1.860164 1.917518 3.557776
24 1.945440 2.033273 3.535242
25 2.026493 1.804196 3.558001
26 1.875208 1.797288 3.511697
27 1.972286 1.798044 3.570746
28 1.950635 1.951106 3.525849
29 2.013820 1.937439 3.592216
30 1.959477 1.959566 3.561970
31 2.012466 2.110288 3.539897
32 1.982466 1.905378 3.559940
33 1.957023 2.207354 3.540873
34 2.049188 1.942845 3.638360
35 1.913136 1.891638 3.581291
36 2.037127 1.808180 3.572567
37 2.006383 2.048738 3.553226
38 2.000312 1.657985 3.594805
39 1.973937 1.891186 3.533843
40 2.002513 1.853988 3.554688


In [ ]:
learn.save('gan-1c')

In [ ]:
learn.data=get_data(16,192)

In [ ]:
learn.fit(10,lr/2)


Total time: 43:07

epoch train_loss gen_loss disc_loss
1 2.578580 2.415008 4.716179
2 2.620808 2.487282 4.729377
3 2.596190 2.579693 4.796489
4 2.701113 2.522197 4.821410
5 2.545030 2.401921 4.710739
6 2.638539 2.548171 4.776103
7 2.551988 2.513859 4.644952
8 2.629724 2.490307 4.701890
9 2.552170 2.487726 4.728183
10 2.597136 2.478334 4.649708


In [ ]:
learn.show_results(rows=16)



In [ ]:
learn.save('gan-1c')