Pretrained GAN


In [ ]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *

In [ ]:
path = untar_data(URLs.PETS)
path_hr = path/'images'
path_lr = path/'crappy'

Crappified data

Prepare the input data by crappifying images.


In [ ]:
from crappify import *
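The crappify module itself is not shown in this notebook. The idea is to degrade clean images so the model can learn to undo the damage. The sketch below is a dependency-free illustration of one such degradation, and it is only an assumption about the kind of damage involved. It downsamples a grayscale image stored as nested lists by averaging blocks, then upsamples it back to the original size with nearest-neighbour repetition, which loses fine detail. The helper names are hypothetical. The real crappifier works on image files and may apply other degradations as well.

```python
from typing import List

Image2D = List[List[float]]  # hypothetical: a grayscale image as rows of pixel values


def downsample(img: Image2D, factor: int) -> Image2D:
    """Average each factor x factor block into one pixel (hypothetical helper)."""
    h, w = len(img), len(img[0])
    if h % factor or w % factor:
        raise ValueError("image size must be divisible by factor")
    out = []
    for r in range(0, h, factor):
        row = []
        for c in range(0, w, factor):
            block = [img[r + i][c + j] for i in range(factor) for j in range(factor)]
            row.append(sum(block) / len(block))
        out.append(row)
    return out


def upsample_nearest(img: Image2D, factor: int) -> Image2D:
    """Repeat every pixel factor times horizontally and vertically."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(factor)]
        out.extend(list(wide) for _ in range(factor))
    return out


def crappify_sketch(img: Image2D, factor: int = 2) -> Image2D:
    """Low-res-then-upscale degradation: same size as the input, less detail.
    Only an illustration; not the notebook's crappify implementation."""
    return upsample_nearest(downsample(img, factor), factor)


if __name__ == "__main__":
    hr = [[0, 4, 0, 4],
          [4, 0, 4, 0],
          [1, 1, 3, 3],
          [1, 1, 3, 3]]
    for row in crappify_sketch(hr):
        print(row)
```

The checkerboard in the top half averages out to a flat value of 2.0, and that lost detail is what the generator must learn to restore.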

Uncomment these lines the first time you run this notebook to create the crappified images in path_lr.


In [ ]:
#il = ImageList.from_folder(path_hr)
#parallel(crappifier(path_lr, path_hr), il.items)

For progressive (gradual) resizing, change which of the bs,size lines below is commented out.


In [ ]:
bs, size = 32, 128
# bs, size = 24, 160
# bs, size = 8, 256
arch = models.resnet34
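The commented-out alternatives shrink the batch size as the image size grows, so GPU memory use stays roughly constant. As a hedged heuristic, and not something the notebook itself computes, you can hold the number of pixels per batch (bs * size * size) fixed and derive the batch size from it. The function name and the rounding to a multiple of 4 are assumptions made for illustration.

```python
def batch_size_for(size: int, pixel_budget: int, multiple: int = 4) -> int:
    """Hypothetical heuristic: the largest batch size (rounded down to a multiple
    of `multiple`) such that bs * size * size stays within `pixel_budget`."""
    bs = pixel_budget // (size * size)
    bs -= bs % multiple
    if bs < 1:
        raise ValueError("pixel budget too small for this image size")
    return bs


if __name__ == "__main__":
    budget = 32 * 128 * 128  # pixels per batch at the first stage (bs=32, size=128)
    for size in (128, 160, 256):
        print(size, batch_size_for(size, budget))
```

The notebook's own choices at 160px (24) and later at 192px (16) are somewhat larger than this budget allows, so treat the result as a starting point to tune against available GPU memory.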

Pre-train generator

Now let's pretrain the generator.


In [ ]:
src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)

In [ ]:
def get_data(bs,size):
    data = (src.label_from_func(lambda x: path_hr/x.name)
           .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
           .databunch(bs=bs).normalize(imagenet_stats, do_y=True))

    data.c = 3
    return data
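The labelling function lambda x: path_hr/x.name pairs each crappified input with the high-resolution file of the same name. The stdlib sketch below shows the same pairing rule. Unlike the lambda, it also checks that each target exists up front, which is an addition made for illustration. In the notebook, a missing file would only surface when it is loaded. The function name is hypothetical.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def pair_by_name(lr_files: List[Path], hr_dir: Path) -> List[Tuple[Path, Path]]:
    """Pair each low-res file with hr_dir / <same file name>, mirroring
    `lambda x: path_hr/x.name`. The existence check is an added assumption."""
    pairs: List[Tuple[Path, Path]] = []
    missing: List[Path] = []
    for x in lr_files:
        y = hr_dir / x.name
        if y.exists():
            pairs.append((x, y))
        else:
            missing.append(x)
    if missing:
        raise FileNotFoundError(f"no high-res target for: {[str(m) for m in missing]}")
    return pairs


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "crappy").mkdir()
        (root / "images").mkdir()
        for name in ("cat_1.jpg", "dog_2.jpg"):
            (root / "crappy" / name).touch()
            (root / "images" / name).touch()
        for x, y in pair_by_name(sorted((root / "crappy").iterdir()), root / "images"):
            print(x.name, "->", y.parent.name + "/" + y.name)
```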

In [ ]:
data_gen = get_data(bs,size)

In [ ]:
data_gen.show_batch(4)



In [ ]:
wd = 1e-3

In [ ]:
y_range = (-3.,3.)
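Because the targets are normalized with imagenet_stats, the generator's outputs are normalized pixel values rather than values in [0, 1]. The sketch below computes the range those values can occupy, using the standard ImageNet channel means and standard deviations. It then shows a scaled sigmoid that squashes an activation into y_range. We believe this matches what fastai's SigmoidRange layer does when y_range is passed, but treat that as an assumption. The helper names are hypothetical.

```python
import math

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sigmoid_range(x: float, low: float, high: float) -> float:
    """Map an unbounded activation into [low, high] (assumed SigmoidRange behaviour)."""
    return low + (high - low) * _sigmoid(x)


def normalized_pixel_range(mean, std):
    """Smallest and largest value a [0, 1] pixel can take after (p - mean) / std."""
    lows = [(0.0 - m) / s for m, s in zip(mean, std)]
    highs = [(1.0 - m) / s for m, s in zip(mean, std)]
    return min(lows), max(highs)


if __name__ == "__main__":
    lo, hi = normalized_pixel_range(IMAGENET_MEAN, IMAGENET_STD)
    print(f"normalized pixels lie in [{lo:.3f}, {hi:.3f}]")  # comfortably inside (-3, 3)
```

Normalized pixels stay within roughly [-2.1, 2.6], so (-3, 3) covers every reachable value with a little headroom.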

In [ ]:
loss_gen = MSELossFlat()
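During pretraining, the generator is trained with a plain pixel-wise loss. As we understand it, MSELossFlat flattens prediction and target before computing the mean squared error. The sketch below does the same on nested Python lists instead of tensors, and the helper names are hypothetical.

```python
from typing import List


def flatten(x) -> List[float]:
    """Recursively flatten nested lists/tuples of numbers (like tensor.view(-1))."""
    if isinstance(x, (list, tuple)):
        out: List[float] = []
        for item in x:
            out.extend(flatten(item))
        return out
    return [float(x)]


def mse_flat(pred, target) -> float:
    """Mean squared error over every value, after flattening both inputs."""
    p, t = flatten(pred), flatten(target)
    if len(p) != len(t) or not p:
        raise ValueError(f"size mismatch or empty input: {len(p)} vs {len(t)}")
    return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)


if __name__ == "__main__":
    pred = [[[0.0, 1.0], [2.0, 3.0]]]    # one 1x2x2 "image"
    target = [[[0.0, 1.0], [2.0, 5.0]]]
    print(mse_flat(pred, target))
```

Only one of the four pixels differs, by 2, so the loss is 4 / 4 = 1.0.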

In [ ]:
def create_gen_learner():
    return unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Weight,
                         self_attention=True, y_range=y_range, loss_func=loss_gen)

In [ ]:
learn_gen = create_gen_learner()

In [ ]:
learn_gen.fit_one_cycle(2, pct_start=0.8)


Total time: 01:35

epoch train_loss valid_loss
1 0.061653 0.053493
2 0.051248 0.047272
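With pct_start=0.8, the one-cycle schedule spends 80% of the iterations warming the learning rate up and only the last 20% annealing it down. The sketch below reproduces the learning-rate half of that schedule with cosine interpolation. The defaults div_factor=25 and final_div=div_factor*1e4 are our understanding of fastai v1's values and should be treated as assumptions. Momentum cycling is omitted, and max_lr is left at the library default in the cell above.

```python
import math
from typing import List


def annealing_cos(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lrs(n_iter: int, lr_max: float, pct_start: float = 0.3,
                  div_factor: float = 25.0, final_div: float = 25.0 * 1e4) -> List[float]:
    """Per-iteration learning rates: warm up from lr_max/div_factor to lr_max over
    pct_start of the iterations, then anneal toward lr_max/final_div.
    Default divisors are assumed, not taken from the notebook."""
    if n_iter < 2:
        raise ValueError("need at least two iterations")
    n_up = min(max(1, int(n_iter * pct_start)), n_iter - 1)
    n_down = n_iter - n_up
    lrs = [annealing_cos(lr_max / div_factor, lr_max, i / n_up) for i in range(n_up)]
    lrs += [annealing_cos(lr_max, lr_max / final_div, i / n_down) for i in range(n_down)]
    return lrs


if __name__ == "__main__":
    for i, lr in enumerate(one_cycle_lrs(10, 1e-3, pct_start=0.8)):
        print(i, f"{lr:.2e}")
```

A long warm-up suits this stage because only the freshly initialized decoder is being trained while the pretrained encoder is still frozen.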


In [ ]:
learn_gen.unfreeze()

In [ ]:
learn_gen.fit_one_cycle(3, slice(1e-6,1e-3))


Total time: 02:24

epoch train_loss valid_loss
1 0.050429 0.046088
2 0.049056 0.043954
3 0.045437 0.043146
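After unfreezing, slice(1e-6,1e-3) gives each layer group its own learning rate. The earliest (pretrained) layers get the smallest rate and the head gets the largest. Our understanding is that fastai spaces these rates geometrically between the two ends. The sketch below illustrates that spacing. The number of layer groups depends on the model, so the 3 used here is only an example.

```python
from typing import List


def even_mults(start: float, stop: float, n: int) -> List[float]:
    """n learning rates spaced evenly on a log scale from start to stop
    (assumed to match how slice(start, stop) is spread over layer groups)."""
    if n < 2:
        return [stop]  # sketch choice: a single group just gets the top rate
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


if __name__ == "__main__":
    for group, lr in enumerate(even_mults(1e-6, 1e-3, 3)):
        print(f"layer group {group}: lr = {lr:.2e}")
```

With three groups, the middle group gets about 3.2e-5, the geometric mean of 1e-6 and 1e-3.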


In [ ]:
learn_gen.show_results(rows=4)



In [ ]:
learn_gen.save('gen-pre2')

Save generated images


In [ ]:
learn_gen.load('gen-pre2');

In [ ]:
name_gen = 'image_gen'
path_gen = path/name_gen

In [ ]:
# shutil.rmtree(path_gen)

In [ ]:
path_gen.mkdir(exist_ok=True)

In [ ]:
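def save_preds(dl):
    "Run the generator over `dl` and save each prediction under its source file name."
    i = 0
    names = dl.dataset.items  # fix_dl is not shuffled, so batch order matches item order
    for b in dl:
        # reconstruct=True turns raw outputs back into fastai Image objects
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen/names[i].name)
            i += 1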

In [ ]:
save_preds(data_gen.fix_dl)

In [ ]:
PIL.Image.open(path_gen.ls()[0])



Train critic


In [ ]:
learn_gen=None
gc.collect()


Out[ ]:
3755

Pretrain the critic to distinguish generated images (image_gen) from real ones (images).


In [ ]:
def get_crit_data(classes, bs, size):
    src = ImageList.from_folder(path, include=classes).split_by_rand_pct(0.1, seed=42)
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(get_transforms(max_zoom=2.), size=size)
           .databunch(bs=bs).normalize(imagenet_stats))
    data.c = 3
    return data
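get_crit_data keeps only files from the listed subfolders and labels each file by the folder it came from. That makes the critic a binary classifier: generated versus real. The sketch below mimics that labelling on plain path strings. The class index follows the order of the classes list, and that ordering is an assumption of the sketch. The function name is hypothetical.

```python
from pathlib import PurePosixPath
from typing import Dict, List, Sequence, Tuple


def label_by_parent(paths: Sequence[str], classes: Sequence[str]) -> List[Tuple[str, int]]:
    """Keep files whose parent folder is one of `classes` and label each with
    that folder's index in `classes` (hypothetical stand-in for label_from_folder)."""
    index: Dict[str, int] = {c: i for i, c in enumerate(classes)}
    out: List[Tuple[str, int]] = []
    for p in paths:
        parent = PurePosixPath(p).parent.name
        if parent in index:
            out.append((p, index[parent]))
    return out


if __name__ == "__main__":
    files = ["pets/image_gen/a.jpg", "pets/images/a.jpg", "pets/crappy/a.jpg"]
    print(label_by_parent(files, ["image_gen", "images"]))  # crappy/ is filtered out
```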

In [ ]:
data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)

In [ ]:
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)



In [ ]:
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())
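As we understand it, the critic produces several logits per image (one per region) while each image has a single real/fake label. AdaptiveLoss bridges the two by repeating the label across all of an image's logits before applying BCEWithLogitsLoss. The pure-Python sketch below shows that expansion together with the numerically stable binary cross-entropy on logits. The function names are hypothetical, and the per-region output layout is an assumption about gan_critic.

```python
import math
from typing import List, Sequence


def bce_with_logits(x: float, y: float) -> float:
    """Stable BCE of sigmoid(x) against target y: max(x,0) - x*y + log(1+exp(-|x|))."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def adaptive_bce(outputs: Sequence[Sequence[float]], targets: Sequence[float]) -> float:
    """Repeat each image's label across all of its logits, then average the BCE."""
    if len(outputs) != len(targets):
        raise ValueError("one target per image expected")
    losses: List[float] = []
    for logits, y in zip(outputs, targets):
        losses.extend(bce_with_logits(x, y) for x in logits)
    if not losses:
        raise ValueError("no logits given")
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # two images: a real one (label 1) scored confidently, a fake one (label 0) unsure
    print(adaptive_bce([[3.0, 2.5, 4.0], [0.1, -0.2, 0.0]], [1.0, 0.0]))
```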

In [ ]:
def create_critic_learner(data, metrics):
    return Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd)

In [ ]:
learn_critic = create_critic_learner(data_crit, accuracy_thresh_expand)
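The metric accuracy_thresh_expand scores the critic in the same expanded way as the loss. Our reading is that it applies a sigmoid to each logit and thresholds at 0.5. It then compares the result with the image's label repeated across all of that image's logits and averages the hits. The sketch below is a hypothetical pure-Python equivalent, not the library function.

```python
import math
from typing import Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def patch_accuracy(outputs: Sequence[Sequence[float]], targets: Sequence[int],
                   thresh: float = 0.5) -> float:
    """Fraction of logits whose thresholded sigmoid matches the image's label,
    with each label repeated across all of that image's logits (sketch)."""
    hits = total = 0
    for logits, y in zip(outputs, targets):
        for x in logits:
            hits += int((sigmoid(x) > thresh) == bool(y))
            total += 1
    if total == 0:
        raise ValueError("no logits given")
    return hits / total


if __name__ == "__main__":
    print(patch_accuracy([[2.0, -2.0], [-3.0, -1.0]], [1, 0]))  # 3 of 4 logits correct
```

On this scale, the 0.531 after the first epoch in the log below is close to chance, and the later values around 0.97 show that the critic separates generated from real images well.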

In [ ]:
learn_critic.fit_one_cycle(6, 1e-3)


Total time: 09:40

epoch train_loss valid_loss accuracy_thresh_expand
1 0.678256 0.687312 0.531083
2 0.434768 0.366180 0.851823
3 0.186435 0.128874 0.955214
4 0.120681 0.072901 0.980228
5 0.099568 0.107304 0.962564
6 0.071958 0.078094 0.976239


In [ ]:
learn_critic.save('critic-pre2')

GAN

Now we'll combine those pretrained models in a GAN.


In [ ]:
learn_critic=None
learn_gen=None
gc.collect()


Out[ ]:
15794

In [ ]:
data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)

In [ ]:
learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre2')

In [ ]:
learn_gen = create_gen_learner().load('gen-pre2')

To define a GAN Learner, we just have to specify the learner objects for the generator and the critic. The switcher is a callback that decides when to switch from discriminator to generator and vice versa. Here we do as many iterations of the discriminator as needed to get its loss below critic_thresh (0.65 in the cell below), then one iteration of the generator.
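The switching rule just described can be simulated without any model. Feed the sketch a stream of critic batch losses. It keeps taking critic steps until one loss falls below the threshold, then records a single generator step and goes back to the critic. Starting with the critic is an assumption of this sketch, and the function name is hypothetical. It is not the AdaptiveGANSwitcher code.

```python
from typing import Iterator, List


def adaptive_schedule(critic_losses: Iterator[float], critic_thresh: float,
                      n_gen_steps: int) -> List[str]:
    """Train the critic until its batch loss is below critic_thresh, then do one
    generator step; repeat until n_gen_steps generator steps have run (sketch)."""
    schedule: List[str] = []
    gen_done = 0
    while gen_done < n_gen_steps:
        loss = next(critic_losses)  # one critic iteration and its loss
        schedule.append("critic")
        if loss < critic_thresh:
            schedule.append("gen")  # a single generator iteration, then back to the critic
            gen_done += 1
    return schedule


if __name__ == "__main__":
    print(adaptive_schedule(iter([0.9, 0.7, 0.6, 0.8, 0.5]), 0.65, 2))
```

A higher threshold hands control back to the generator sooner. A lower one forces the critic to become more accurate before each generator update.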

The loss of the critic is given by learn_crit.loss_func. We take the average of this loss function on the batch of real predictions (target 1) and the batch of fake predictions (target 0).
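The averaging described above can be sketched on plain lists of logits. Real images are scored against target 1 and generated images against target 0, and the two means are averaged. Binary cross-entropy on logits stands in for learn_crit.loss_func here, and the helper names are hypothetical.

```python
import math
from typing import Sequence


def bce_with_logits(x: float, y: float) -> float:
    """Stable binary cross-entropy of sigmoid(x) against target y."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def mean_bce(logits: Sequence[float], target: float) -> float:
    """Average BCE of a batch of logits against one shared target."""
    return sum(bce_with_logits(x, target) for x in logits) / len(logits)


def critic_loss(real_logits: Sequence[float], fake_logits: Sequence[float]) -> float:
    """Average of the loss on real predictions (target 1) and fake ones (target 0)."""
    return (mean_bce(real_logits, 1.0) + mean_bce(fake_logits, 0.0)) / 2


if __name__ == "__main__":
    print(critic_loss([0.0, 0.0], [0.0]))   # undecided critic: ln 2
    print(critic_loss([4.0], [-4.0]))       # confident, correct critic: small loss
```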

The loss of the generator is a weighted sum (weights in weights_gen) of learn_crit.loss_func on the batch of fakes (passed through the critic to become predictions) with a target of 1, and learn_gen.loss_func applied to the output (the batch of fakes) and the target (the corresponding batch of superres images).
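With weights_gen=(1.,50.) as in the cell below, the generator loss combines two terms. The adversarial term (the critic's loss on the fakes, scored against target 1) gets weight 1, and the pixel-wise MSE against the real high-res images gets weight 50. The sketch below computes that sum on flat lists. BCE on logits and plain MSE stand in for the two learners' loss functions, and the function names are hypothetical.

```python
import math
from typing import Sequence, Tuple


def bce_with_logits(x: float, y: float) -> float:
    """Stable binary cross-entropy of sigmoid(x) against target y."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def generator_loss(fake_logits: Sequence[float], fake_pixels: Sequence[float],
                   real_pixels: Sequence[float],
                   weights: Tuple[float, float] = (1.0, 50.0)) -> float:
    """weights[0] * critic loss on fakes with target 1 (fool the critic)
    + weights[1] * MSE between fakes and the real high-res pixels."""
    adv = sum(bce_with_logits(x, 1.0) for x in fake_logits) / len(fake_logits)
    if len(fake_pixels) != len(real_pixels):
        raise ValueError("pixel count mismatch")
    pix = sum((a - b) ** 2 for a, b in zip(fake_pixels, real_pixels)) / len(fake_pixels)
    return weights[0] * adv + weights[1] * pix


if __name__ == "__main__":
    print(generator_loss([0.0], [0.0, 1.0], [0.0, 0.0]))  # ln 2 + 50 * 0.5
```

The large weight on the pixel term keeps the output anchored to the real image, while the adversarial term pushes it toward looking realistic to the critic.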


In [ ]:
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

In [ ]:
lr = 1e-4

In [ ]:
learn.fit(40,lr)


Total time: 1:05:41

epoch train_loss gen_loss disc_loss
1 2.071352 2.025429 4.047686
2 1.996251 1.850199 3.652173
3 2.001999 2.035176 3.612669
4 1.921844 1.931835 3.600355
5 1.987216 1.961323 3.606629
6 2.022372 2.102732 3.609494
7 1.900056 2.059208 3.581742
8 1.942305 1.965547 3.538015
9 1.954079 2.006257 3.593008
10 1.984677 1.771790 3.617556
11 2.040979 2.079904 3.575464
12 2.009052 1.739175 3.626755
13 2.014115 1.204614 3.582353
14 2.042148 1.747239 3.608723
15 2.113957 1.831483 3.684338
16 1.979398 1.923163 3.600483
17 1.996756 1.760739 3.635300
18 1.976695 1.982629 3.575843
19 2.088960 1.822936 3.617471
20 1.949941 1.996513 3.594223
21 2.079416 1.918284 3.588732
22 2.055047 1.869254 3.602390
23 1.860164 1.917518 3.557776
24 1.945440 2.033273 3.535242
25 2.026493 1.804196 3.558001
26 1.875208 1.797288 3.511697
27 1.972286 1.798044 3.570746
28 1.950635 1.951106 3.525849
29 2.013820 1.937439 3.592216
30 1.959477 1.959566 3.561970
31 2.012466 2.110288 3.539897
32 1.982466 1.905378 3.559940
33 1.957023 2.207354 3.540873
34 2.049188 1.942845 3.638360
35 1.913136 1.891638 3.581291
36 2.037127 1.808180 3.572567
37 2.006383 2.048738 3.553226
38 2.000312 1.657985 3.594805
39 1.973937 1.891186 3.533843
40 2.002513 1.853988 3.554688


In [ ]:
learn.save('gan-1c')

In [ ]:
learn.data=get_data(16,192)

In [ ]:
learn.fit(10,lr/2)


Total time: 43:07

epoch train_loss gen_loss disc_loss
1 2.578580 2.415008 4.716179
2 2.620808 2.487282 4.729377
3 2.596190 2.579693 4.796489
4 2.701113 2.522197 4.821410
5 2.545030 2.401921 4.710739
6 2.638539 2.548171 4.776103
7 2.551988 2.513859 4.644952
8 2.629724 2.490307 4.701890
9 2.552170 2.487726 4.728183
10 2.597136 2.478334 4.649708


In [ ]:
learn.show_results(rows=16)



In [ ]:
learn.save('gan-1c')

fin

