Building Autoencoders in Keras PyTorch

WNixalo – 2018/6/16-20

Building Autoencoders in Keras


In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch

In [3]:
import matplotlib.pyplot as plt

In [4]:
import copy

1. Building the simplest possible autoencoder

We'll start simple, with a single fully-connected neural layer as encoder and as decoder:

from keras.layers import Input, Dense
from keras.models import Model

# this is the size of our encoded representations
encoding_dim = 32 # 32 floats -> compression of factor 24.5 assuming input is 784 floats

# this is our input placeholder
input_img = Input(shape=(784,))
# "encoded" is the encoded representation of the input
encoded = Dense(encoding_dim, activation='relu')(input_img)
# "decoded" is the lossy reconstruction of the input
decoded = Dense(784, activation='sigmoid')(encoded)

# this model maps an input to its reconstruction
autoencoder = Model(input_img, decoded)

Let's also create a separate encoder model:

# this model maps an input to its encoded representation
encoder = Model(input_img, encoded)

As well as the decoder model:

# create a placeholder for an encoded (32-dimensional) input
encoded_input = Input(shape=(encoding_dim,))
# retrieve the last layer of the autoencoder model
decoder_layer = autoencoder.layers[-1]
# create the decoder model
decoder = Model(encoded_input, decoder_layer(encoded_input))

In [5]:
import torch.nn as nn
import torch.nn.functional as F

In [6]:
input_size   = 784
encoding_dim = 32

In [7]:
# writing full classes for the en/de-coders is overkill, but this is the 
# general form of writing pytorch modules.

class Encoder(nn.Module):
    def __init__(self, input_size, encoding_dim):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(input_size, encoding_dim), nn.ReLU()])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
            
class Decoder(nn.Module):
    def __init__(self, encoding_dim, input_size):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(encoding_dim, input_size)])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class Autoencoder(nn.Module):
    def __init__(self, input_size=784, encoding_dim=32):
        super().__init__()
        self.encoder = Encoder(input_size, encoding_dim)
        self.decoder = Decoder(encoding_dim, input_size)
    def forward(self, x):
        x = x.view(x.size(0), -1) # flatten to (nm,1) vector
        x = self.encoder(x)  # Encode
        x = self.decoder(x)  # Decode
        x = F.sigmoid(x)
        x = x.reshape(x.size(0),1,28,28) # should be a more elegant way to do this
        return x

In [8]:
autoencoder = Autoencoder(784, 32)

Now let's train our autoencoder to reconstruct MNIST digits.

First, we'll configure our model to use a per-pixel binary crossentropy loss, and the Adadelta optimizer:

autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

Let's prepare our input data. We're using MNIST digits, and we're discarding the labels (since we're only interested in encoding/decoding the input images).

from keras.datasets import mnist
import numpy as np
(x_train, _), (x_test, _) = mnist.load_data()

We will normalize all values between 0 and 1 and we will flatten the 28x28 images into vectors of size 784.

x_train = x_train.astype('float32') / 255.
x_test  = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test  = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
print(x_train.shape)
print(x_test.shape)

In [10]:
# # these are actually initialized within the training loop
# optimizer = torch.optim.Adadelta(autoencoder.parameters())
# criterion = torch.nn.BCELoss()

In [8]:
import torchvision

In [9]:
bs = 16

# stats = [[0.1307],[0.3073]] # calculated from training set

tfm0 = torchvision.transforms.ToTensor()  # convert [0,255] -> [0.0,1.0]
# tfm1 = torchvision.transforms.Normalize(*stats) # normalize to [-1.0,+1.0]

# tfms = [tfm0, tfm1]
# tfms = torchvision.transforms.Compose(tfms)

In [10]:
train_dataset = torchvision.datasets.MNIST('data/MNIST/',train=True, transform=tfm0)
test_dataset  = torchvision.datasets.MNIST('data/MNIST/',train=False,transform=tfm0)

train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader   = torch.utils.data.DataLoader(test_dataset,  batch_size=bs)

Aside: Flattening Tensors and Untrained Autoencoder output

multiple ways to flatten tensors:


In [10]:
# get minibatch
x,_ = next(iter(train_loader)); x_test,_ = next(iter(test_loader))
# check ea. way to flatten is identical
compare0 = torch.equal(np.reshape(x, (len(x), np.prod(x.shape[1:]))),
                       x.reshape(len(x), np.prod(x.shape[1:])))
compare1 = torch.equal(x.view(x.size(0), -1), x.reshape(len(x), np.prod(x.shape[1:])))
print(True == compare0 == compare1)
# display flattened minibatch shapes
print(x.view(x.size(0), -1).shape)
print(x_test.view(x_test.size(0), -1).shape)


True
torch.Size([16, 784])
torch.Size([16, 784])

Here's a test of the autoencoder without any training (you'd expect just noise):


In [13]:
x,y = next(iter(train_loader))
z = autoencoder(x)

In [14]:
z.shape


Out[14]:
torch.Size([16, 1, 28, 28])

In [197]:
fig,axes = plt.subplots(1,2); plt.set_cmap(['gray','viridis'][1]);
axes[0].imshow(x[0][0].numpy()); axes[1].imshow(z[0][0].detach().numpy());



In [11]:
def compare_plot(x,z, idx=0, cdx=1):
    """assumes x,y (bs,1,28,28) tensors"""
    fig,axes = plt.subplots(1,2); plt.set_cmap(['gray','viridis'][cdx]);
    if type(z) == torch.Tensor: z = z.detach().numpy()
    axes[0].imshow(x[idx][0].numpy()); axes[1].imshow(z[idx][0]);

Now let's train our autoencoder for 50 epochs:

autoencoder.fit(x_train, x_train, epochs=50, batch_size=256, shuffle=True,
                validation_data=(x_test, x_test))

After 50 epochs, the autoencoder seems to reach a stable train/test loss value of about 0.11. We can try to visualize the reconstrubted inputs and the encoded representations. We will use Matplotlib.

# encode and decode some digits
# note that we take them from the *test* set
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)
# use Matplotlib (don't ask)
import matplotlib.pyplot as plt

n = 10 # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

In [12]:
def train(model, trainloader=None, valloader=None, num_epochs=1):
    # use GPU0 if available # pytorch >= 0.4
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    # name dataloaders for phases
    phases = ['train']
    dataloaders = {'train':trainloader}
    if valloader:
        phases.append('valid')
        dataloaders['valid'] = valloader
    
    # move model to gpu -- pytorch >= 0.4 ## NOTE: havent tested on gpu yet
    model.to(device)
    optimizer = torch.optim.Adadelta(autoencoder.parameters())
    criterion = torch.nn.BCELoss()
    
    # epoch w/ train & val phases
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}\n{"-"*10}')
        
        for phase in phases:
            if phase == 'train':
                #scheduler.step() # (no LR scheduler this time)
                model.train() # set model to training mode
            else: model.eval() # set model to evaluation mode
                
            running_loss, running_correct, count = 0.0, 0, 0
        
            for i, minibatch in enumerate(dataloaders[phase]):
                x,y = minibatch
                x,y = x.to(device), y.to(device)

                # zero param gradients
                optimizer.zero_grad()

                # forward: track history if training phase
                with torch.set_grad_enabled(phase=='train'): # pytorch >= 0.4
                    outputs = model(x)
                    loss    = criterion(outputs, x)
                    preds,_ = torch.max(outputs,1) # for accuracy metric
                    # backward & optimize if training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # stats
                running_loss += loss.item() * x.size(0)
                count += len(x)
            
            epoch_loss = running_loss / count
            print(f'{phase} loss {epoch_loss:.6f}')
        print()

In [205]:
%time train(autoencoder, trainloader=train_loader, valloader=test_loader, num_epochs=1)


Epoch 1/1
----------
CPU times: user 54.4 s, sys: 1.3 s, total: 55.7 s
Wall time: 19.7 s

In [211]:
x,y = next(iter(train_loader))
z = autoencoder(x)
compare_plot(x,z)


Woah.


Aside: Fast.AI

FChollet makes use of TensorBoard to make some fancy plots of the training process. Also, Keras isn't at PyTorch's abstraction level: PyTorch is comparable to TensorFlow, Keras' backend. So instead of doing all my training manually, I'm going to modify the pytorch dataloaders a bit so I can use them with fastai for training. This'll give me some powerful high-level control over the process.


In [13]:
# create copies of dataloaders for ModelData
train_loadermd = copy.deepcopy(train_loader)
test_loadermd  = copy.deepcopy(test_loader)

# set y to be x and convert [0,255] int to [0.0,1.0] float. (dl doesnt trsfm `y` by default)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_data.type(torch.FloatTensor)/255
test_loadermd.dataset.test_labels   = test_loadermd.dataset.test_data.type(torch.FloatTensor)/255

# add channel dimension for compatibility. (bs,h,w) –> (bs,ch,h,w)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_labels.reshape((len(train_loadermd.dataset),1,28,28))
test_loadermd.dataset.test_labels   = test_loadermd.dataset.test_labels.reshape((len(test_loadermd.dataset),1,28,28))

In [14]:
from fastai.conv_learner import *

In [15]:
md = ModelData('data/MNIST', train_loadermd, test_loadermd)

In [20]:
learn = Learner.from_model_data(Autoencoder(), md)
learn.crit = F.binary_cross_entropy

In [174]:
learn.lr_find()
learn.sched.plot()


epoch      trn_loss   val_loss                                  
    0      0.142973   0.154716  

In [175]:
z = learn.predict()        # run learner on val data
x = md.val_ds.test_data[0] # get X from val dataset
compare_plot([[x]], z) # expects dims: (batchsize, channels, rows, cols)



In [176]:
learn.fit(0.01, 1)     # fit learner to data
z = learn.predict()    # run learner on val data
compare_plot([[x]], z) # display X vs Z


epoch      trn_loss   val_loss                                  
    0      0.264682   0.262922  

In [188]:
x = md.val_ds.test_data.reshape(len(md.val_ds),1,28,28)
compare_plot(x, z, idx=10)



Now that I got fastai working with the data, I can easily train 50 epochs. Although, to show how quickly the model converges w/ MNIST data, here's the result of a couple epochs via pure-pytorch:


In [221]:
autoencoder = Autoencoder()
train(autoencoder, train_loader, test_loader, num_epochs=2)


Epoch 1/2
----------
train loss 0.200236
valid loss 0.148270

Epoch 2/2
----------
train loss 0.131812
valid loss 0.117702


In [223]:
x,y = next(iter(test_loader))
z = autoencoder(x)
compare_plot(x, z)


Now to train for 50 epochs with fastai:


In [242]:
autoencoder = Autoencoder()
learner = Learner.from_model_data(autoencoder, md)
learner.crit = F.binary_cross_entropy
learner.opt_fn = torch.optim.Adadelta

In [243]:
learner.lr_find()
learner.sched.plot()


epoch      trn_loss   val_loss                                 
    0      0.180425   0.171059  

Looks like I'll have to use a really aggressive learning rate just to get started anywhere. The torch.optim.Adadelta optimizer's default was 1.0 in the pytorch training loop. I'm training for 50 epochs anyway (to match the tutorial) so I'll go with 0.5. The tutorial is also using batch sizes of 256, but I'm using 16 (for no particular reason than not wanting to reinitialize everything).

Nothing special with training: just 50 cycles, each 1 epoch long, using default Cosine Annealing, no weight decay.


In [245]:
learner.fit(0.5, n_cycle=50)


epoch      trn_loss   val_loss                                 
    0      0.180503   0.179986  
    1      0.146636   0.145664                                 
    2      0.127929   0.126243                                 
    3      0.116537   0.114927                                 
    4      0.110755   0.108692                                 
    5      0.106966   0.105151                                 
    6      0.104011   0.102819                                 
    7      0.102981   0.101205                                 
    8      0.100336   0.099909                                  
    9      0.101887   0.099023                                   
    10     0.10073    0.098413                                   
    11     0.098454   0.097746                                   
    12     0.099941   0.097324                                   
    13     0.096887   0.096923                                   
    14     0.097922   0.096655                                   
    15     0.099135   0.096455                                   
    16     0.09719    0.096281                                   
    17     0.096672   0.096134                                   
    18     0.097642   0.095951                                   
    19     0.096808   0.095842                                   
    20     0.097703   0.095802                                   
    21     0.097251   0.0957                                     
    22     0.097903   0.095605                                   
    23     0.097281   0.095573                                   
    24     0.096383   0.095419                                   
    25     0.096412   0.095387                                   
    26     0.096422   0.095451                                   
    27     0.097631   0.095294                                   
    28     0.096548   0.095258                                   
    29     0.096357   0.095187                                   
    30     0.096116   0.095171                                   
    31     0.096966   0.095122                                   
    32     0.096267   0.095071                                   
    33     0.097147   0.094986                                   
    34     0.095909   0.094997                                   
    35     0.097563   0.094984                                   
    36     0.095372   0.094918                                   
    37     0.096651   0.09493                                    
    38     0.096898   0.094867                                  
    39     0.096978   0.094793                                   
    40     0.096873   0.094749                                   
    41     0.09536    0.094688                                   
    42     0.094865   0.094657                                   
    43     0.095538   0.094549                                   
    44     0.095497   0.094434                                   
    45     0.096441   0.094239                                   
    46     0.096006   0.094079                                   
    47     0.093832   0.09398                                    
    48     0.095412   0.09387                                    
    49     0.095988   0.093775                                   
Out[245]:
[0.09377471673488617]

In [246]:
learn.save('pytorch-autoencoder-50ep')

In [263]:
x,y = next(iter(md.val_dl)) # get 1st minibatch
z = learner.predict()

In [266]:
compare_plot(x,z)



In [16]:
def compare_batch(x, z, bs=16, figsize=(16,2)):
    bs = min(len(x), bs) # digits to display
    fig = plt.figure(figsize=figsize)
    for i in range(bs):
        # display original
        ax = plt.subplot(2, bs, i+1); ax.imshow(x[i].reshape(28,28))
        ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)
        
        # display reconstruction
        ax = plt.subplot(2, bs, i+1+bs); ax.imshow(z[i].reshape(28,28))
        ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)

In [286]:
compare_batch(x,z)


Here's what we get. The top row is the original digits, and the bottom row is the reconstructed digits. We are losing quite a bit of detail with this basic approach.

2. Adding a sparsity constraint on the encoded representations

In the previous example, the representations were only constrained by the size of the hidden layer (32). In such a situation, what typically ahppens is that the hidden layer is learning an approximation of PCA (principal component analysis). But another way to constrain the representations to be compact is to add a sparsity constraint on the activity of the hidden representations, so fewer units would "fire" at a given time. In Keras, this can be done by adding an activity_regularizer to our Dense layer:

from keras import regularizers

encoding_dim = 32

input_img = Input(shape=(784,))
# add a Dense layer with a L1 activity regularizer
encoded = Dense(encoding_dim, activation='relu',
                activity_regularizer=regularizers.l1(10e-5))(input_img)
decoded = Dense(784, activation='sigmoid')(encoded)

autoencoder = Model(input_img, decoded)

Let's train this model for 100 epochs (with the added regularization the model is less likely to overfit and can be trained longer). The model ends with a train loss of 0.11 and a test loss of 0.10. The difference between the two is mostly due to the regularization term being added to the loss during training (worth about 0.01).

Fast AI's Learner has a built-in regularization function attribute.

Internally in fastai.model.Stepper in the step(.) function, the output is calculated by passing the input/s into the model self.m(.). If the output is a tuple, in the case of multi-headed models or models that also output intermediate activations, output is reassigned and destructured: output is reassigned to it's 1st item, and xtra is a list of all the rest.

raw_loss is first calculated on the output and y using the loss function self.crit. If no regularization function is attached to the Learner, the raw_loss is returned as the loss. Otherwise, output, xtra, and raw_loss are all passed into the regularizer self.reg_fn and the result is returned as loss.

So adding L1 (or any) regularization to a Fast AI Learner is as easy as defining a regularization function that accepts arguments: output, xtra, raw_loss and assigning it to learner.reg_fn. Also make sure the encoder sends a copy of it

This raises a further question of how exactly was multi-head / output work done, for example in pascal.ipynb (multi-head multi-output) and lesson2-image_models.ipynb (multi-output). In lesson2-image_models that abstraction layer is still hidden, but in pascal.ipynb the SSD model's OutConv nn.Module class outputs a list of [classifications, regressions], so I think that's the big clue. I tested this and the output indeed didn't get destructured.


In [17]:
# writing full classes for the en/de-coders is overkill, but this is the 
# general form of writing pytorch modules.

class EncoderL1(nn.Module):
    def __init__(self, input_size, encoding_dim):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(input_size, encoding_dim), nn.ReLU()])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
            
class Decoder(nn.Module):
    def __init__(self, encoding_dim, input_size):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(encoding_dim, input_size)])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class AutoencoderL1(nn.Module):
    """Basic autoencoder with extra encoder output for L1 regularization."""
    def __init__(self, input_size=784, encoding_dim=32):
        super().__init__()
        self.encoder = EncoderL1(input_size, encoding_dim)
        self.decoder = Decoder(encoding_dim, input_size)
    def forward(self, x):
        x = x.view(x.size(0), -1) # flatten to (nm,1) vector
        l1x = self.encoder(x)  # Encode
        x = self.decoder(l1x)  # Decode
        x = F.sigmoid(x)
        x = x.reshape(x.size(0),1,28,28) # should be a more elegant way to do this
        return x, l1x # autoencoder and encoder outputs
    
def l1_reg(output, xtra, raw_loss, λ1=1e-4):
    """Adds L1 Regularization to the Loss."""
    return λ1*torch.abs(*xtra).sum() + raw_loss # xtra comes in as list, deconstruct to tensor

In [28]:
learn = Learner.from_model_data(AutoencoderL1(), md)

In [29]:
learn.crit = F.binary_cross_entropy
learn.opt_fn = torch.optim.Adadelta
learn.reg_fn = l1_reg

In [30]:
learn.lr_find()
learn.sched.plot()


epoch      trn_loss   val_loss                                 
    0      0.246507   0.239414  

In [31]:
%time learn.fit(0.5, 100) # same lr as before; now 100 cycles (1 ep/cyc)


epoch      trn_loss   val_loss                                 
    0      0.250728   0.248423  
    1      0.231891   0.231998                                 
    2      0.224072   0.220726                                 
    3      0.218669   0.21725                                  
    4      0.217783   0.21568                                  
    5      0.214344   0.214774                                 
    6      0.214792   0.213973                                 
    7      0.214472   0.21356                                  
    8      0.215391   0.213021                                 
    9      0.214764   0.212651                                 
    10     0.213261   0.212288                                 
    11     0.215662   0.212235                                 
    12     0.215801   0.211687                                 
    13     0.213287   0.211712                                 
    14     0.212363   0.211242                                 
    15     0.214929   0.211148                                 
    16     0.211418   0.210893                                 
    17     0.213183   0.210403                                 
    18     0.210202   0.210322                                 
    19     0.212448   0.210095                                 
    20     0.214052   0.209838                                 
    21     0.212362   0.209638                                 
    22     0.21093    0.209292                                 
    23     0.210809   0.209161                                 
    24     0.208424   0.208805                                 
    25     0.211796   0.208757                                 
    26     0.209186   0.208177                                 
    27     0.202112   0.203661                                 
    28     0.200053   0.198871                                 
    29     0.197028   0.197766                                 
    30     0.198351   0.197138                                 
    31     0.195027   0.196505                                 
    32     0.201171   0.196194                                 
    33     0.196294   0.195889                                 
    34     0.196206   0.195579                                 
    35     0.194749   0.195337                                 
    36     0.197545   0.195441                                 
    37     0.195938   0.194983                                 
    38     0.195462   0.194897                                 
    39     0.195252   0.194728                                 
    40     0.19517    0.194871                                 
    41     0.195441   0.194452                                 
    42     0.19506    0.194397                                 
    43     0.197389   0.194431                                 
    44     0.193827   0.194218                                 
    45     0.194498   0.193927                                 
    46     0.195605   0.194014                                 
    47     0.195702   0.193939                                 
    48     0.195721   0.193727                                 
    49     0.195037   0.193647                                  
    50     0.196885   0.193479                                  
    51     0.196552   0.193214                                 
    52     0.188157   0.193239                                 
    53     0.196309   0.193157                                 
    54     0.194332   0.19309                                  
    55     0.195321   0.193134                                  
    56     0.192788   0.192848                                  
    57     0.19455    0.193147                                  
    58     0.193524   0.192736                                  
    59     0.193172   0.19251                                   
    60     0.194228   0.193028                                  
    61     0.194252   0.192686                                  
    62     0.193488   0.192514                                  
    63     0.193046   0.192446                                  
    64     0.195129   0.192356                                  
    65     0.194313   0.192707                                  
    66     0.193202   0.192403                                  
    67     0.19404    0.192078                                  
    68     0.190446   0.192127                                  
    69     0.194587   0.192071                                  
    70     0.193974   0.192075                                  
    71     0.195237   0.192065                                  
    72     0.194764   0.192281                                  
    73     0.192829   0.19189                                   
    74     0.193417   0.192164                                  
    75     0.192915   0.191963                                  
    76     0.192471   0.191633                                  
    77     0.192812   0.191628                                  
    78     0.192207   0.191699                                  
    79     0.191981   0.191719                                  
    80     0.194317   0.191467                                  
    81     0.193124   0.191468                                  
    82     0.19371    0.191475                                  
    83     0.193718   0.191594                                  
    84     0.194073   0.191415                                 
    85     0.194182   0.191322                                 
    86     0.193363   0.191507                                  
    87     0.189904   0.191335                                 
    88     0.192567   0.191504                                 
    89     0.193546   0.191206                                 
    90     0.195682   0.191548                                  
    91     0.191894   0.191569                                  
    92     0.19303    0.191532                                 
    93     0.191095   0.190958                                 
    94     0.190473   0.190923                                 
    95     0.191526   0.190989                                 
    96     0.192925   0.1911                                   
    97     0.194582   0.190933                                 
    98     0.19148    0.190934                                 
    99     0.192061   0.190795                                 
CPU times: user 2h 21min 21s, sys: 13min 24s, total: 2h 34min 46s
Wall time: 1h 13min 6s
Out[31]:
[0.190795294547081]

I think my learning rate was too low with the added regularization. And I just realized since this is pytorch I could have the model output the encoder outputs for regularization only when in training mode. hmm, next time.


In [36]:
learn.save('autoencoder_l1reg_100ep')

In [69]:
x = next(iter(learn.data.val_dl))[0]
z = learn.predict()

In [72]:
z.shape, x.shape


Out[72]:
((10000, 1, 28, 28), torch.Size([16, 1, 28, 28]))

In [71]:
compare_batch(x, z)


Here's a visualization of our new results:

They look pretty similar to the previous model, the only significant difference being the sparsity of the encoded representations. encoded_imgs.mean() yields a value 3.33 (over our 10,000 test images), whereas with the previous model the same quantity was 7.30. So our new model yields encoded representations that are twice sparser.


In [104]:
encodings = [] # now its convenient Im getting the encoder's output
for x,_ in iter(learn.data.val_dl):
    encodings.append(learn.model(x)[1].detach().numpy().mean())
np.array(encodings).mean()


Out[104]:
0.18292214

3. Deep autoencoder

We do not have to limit ourselves to a single layer as encoder or decoder, we could instead use a stack of layers such as:

input_img = Input(shape(784,))
encoded = Dense(128, activation='relu')(input_img)
encoded = Dense(64,  activation='relu')(encoded)
encoded = Dense(32,  activation='relu')(encoded)

decoded = Dense(64,  actication='relu')(encoded)
decoded = Dense(128, activation='relu')(decoded)
decoded = Dense(784, activation='sigmoid')(decoded)

Let's try this:

autoencoder = Model(inpu_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

autoencoder.fit(x_train, x_train, epochs=100, batch_size=256, shuffle=True, 
                validation_data=(x_test, x_test))

In [18]:
# writing with list comprehensions easily allows for arbitrary-depth networks
# Also I totally forgot the ReLU activations.
class EncoderBlock(nn.Module):
    def __init__(self, input_size=784, encoding_dim=32, n_layers=3):
        super().__init__()
        layers =  [nn.Linear(input_size, encoding_dim*2**(n_layers-1))]
        layers += [nn.Linear(encoding_dim*2**(i), encoding_dim*2**(i-1)) for i in range(n_layers-1,0,-1)]
        self.layers = nn.Sequential(*layers)
    def forward(self, x):
        return self.layers(x)

class DecoderBlock(nn.Module):
    def __init__(self, output_size=784, encoding_dim=32, n_layers=3):
        super().__init__()
        layers =  [nn.Linear(encoding_dim*2**(i), encoding_dim*2**(i+1)) for i in range(n_layers-1)]
        layers += [nn.Linear(encoding_dim*2**(n_layers-1), output_size)]
        self.layers = nn.Sequential(*layers)
    def forward(self, x):
        return self.layers(x)

class DeepAutoencoder(nn.Module):
    def __init__(self, input_size=784, encoding_dim=32, n_layers=3):
        super().__init__()
        self.encoder = EncoderBlock(input_size, encoding_dim, n_layers)
        self.decoder = DecoderBlock(input_size, encoding_dim, n_layers)
    def forward(self, x):
        x = x.view(x.size(0), -1)     # flatten
        enc_x = self.encoder(x)          # encode
        x = self.decoder(enc_x)          # decode
        x = F.sigmoid(x)                 # activtn
        x = x.reshape(x.size(0),1,28,28) # 'unflatten'
        return x, enc_x # also return encoding

In [96]:
learn = Learner.from_model_data(DeepAutoencoder(), md)

In [97]:
learn.crit  = F.binary_cross_entropy
learn.optim = torch.optim.Adadelta

In [98]:
learn.lr_find()
learn.sched.plot()


 86%|████████▋ | 3238/3750 [00:54<00:08, 59.78it/s, loss=0.751]
 86%|████████▋ | 3238/3750 [01:05<00:10, 49.61it/s, loss=0.751]

In [99]:
learn.fit(0.5, 100)


epoch      trn_loss   val_loss                                  
    0      0.120986   0.117343  
    1      0.109043   0.10734                                   
    2      0.104469   0.102813                                  
    3      0.101269   0.099745                                   
    4      0.100716   0.098699                                   
    5      0.099753   0.09749                                   
    6      0.097745   0.096033                                  
    7      0.097524   0.095729                                   
    8      0.096915   0.096204                                   
    9      0.097093   0.094898                                   
    10     0.096377   0.094045                                   
    11     0.095779   0.094725                                   
    12     0.095585   0.094494                                   
    13     0.096055   0.093822                                   
    14     0.094103   0.09436                                    
    15     0.095183   0.093865                                   
    16     0.095914   0.094332                                   
    17     0.094502   0.094081                                   
    18     0.095473   0.094105                                  
    19     0.094822   0.093715                                  
    20     0.095561   0.093849                                   
    21     0.095662   0.093837                                   
    22     0.094887   0.093934                                  
    23     0.096861   0.094294                                  
    24     0.096279   0.093991                                   
    25     0.095046   0.094371                                   
    26     0.095063   0.094341                                   
    27     0.095012   0.094231                                   
    28     0.094759   0.094249                                   
    29     0.095135   0.093953                                   
    30     0.096162   0.093863                                   
    31     0.094675   0.093906                                   
    32     0.095107   0.094484                                   
    33     0.095015   0.09429                                    
    34     0.096256   0.094946                                   
    35     0.095801   0.094216                                   
    36     0.094839   0.094047                                   
    37     0.095284   0.094244                                   
    38     0.094703   0.093797                                   
    39     0.096912   0.094259                                   
    40     0.095523   0.094349                                   
    41     0.095702   0.094206                                   
    42     0.095962   0.093802                                   
    43     0.095005   0.094155                                   
    44     0.093467   0.094107                                   
    45     0.096025   0.093811                                   
    46     0.095721   0.094075                                   
    47     0.094977   0.093953                                   
    48     0.095045   0.094202                                   
    49     0.096944   0.094219                                   
    50     0.09627    0.093771                                   
    51     0.096014   0.0935                                     
    52     0.096414   0.093339                                   
    53     0.095508   0.093996                                   
    54     0.093895   0.094373                                   
    55     0.094883   0.094055                                   
    56     0.094829   0.094283                                   
    57     0.096503   0.094141                                   
    58     0.095673   0.094654                                   
    59     0.096153   0.093631                                   
    60     0.095549   0.0943                                     
    61     0.095763   0.094089                                   
    62     0.095088   0.093627                                   
    63     0.094691   0.094741                                   
    64     0.095099   0.093572                                   
    65     0.094725   0.093783                                   
    66     0.095838   0.094093                                   
    67     0.095692   0.094063                                   
    68     0.095118   0.094092                                  
    69     0.096006   0.094204                                  
    70     0.094288   0.094402                                  
    71     0.094026   0.093477                                  
    72     0.096448   0.094191                                  
    73     0.096163   0.093997                                  
    74     0.094648   0.093809                                  
    75     0.094369   0.094059                                  
    76     0.09508    0.094292                                  
    77     0.095836   0.093932                                  
    78     0.094883   0.09442                                    
    79     0.095707   0.093925                                   
    80     0.095647   0.094595                                   
    81     0.096259   0.093811                                   
    82     0.09501    0.093903                                   
    83     0.093533   0.094032                                   
    84     0.094751   0.09411                                    
    85     0.094731   0.093534                                   
    86     0.09654    0.094085                                   
    87     0.095779   0.094024                                   
    88     0.095645   0.093932                                   
    89     0.095365   0.093965                                   
    90     0.095116   0.093902                                   
    91     0.095641   0.094386                                   
    92     0.096589   0.094478                                   
    93     0.096052   0.093937                                   
    94     0.09489    0.0937                                     
    95     0.095678   0.093895                                   
    96     0.09534    0.09406                                    
    97     0.09602    0.093873                                   
    98     0.095944   0.094997                                   
    99     0.095098   0.093853                                   
Out[99]:
[0.09385320316553115]

In [100]:
learn.save('autoencoder_deep_100ep')

In [102]:
x,y = next(iter(learn.data.val_dl))
z   = learn.predict()
compare_batch(x,z)


After 100 epochs, it reaches a train and test loss of ~0.097, a bit better than our previous models. Our reconstructed digits look a bit better too:

4. Convolutional autoencoder

Since our inputs are images, it makes sense to use convolutional neural networks (convnets) as encoders and decoders. In practical settings, autoencoders applied to images are always convolutional autoencoders –– they simple perform much better.

Let's implement one. The encoder will consist in a stack of Conv2D and MaxPooling2D layers (max pooling being used for spatial down-sampling), while the decoder will conssit in a stack of Conv2D and UpSampling2D layers.

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K

input_img = Input(shape=(28, 28, 1)) # adapt this if using `channels_first` image data format

x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D(2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu' padding='same')(x)
x = MaxPooling2D(2, 2), padding='same'(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D(2, 2), padding='same')(x)

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2d(2, 2)(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

To train it, we will use the original MNIST digits with shape (samples, 3, 28, 28), and we will just normalize pixel values between 0 and 1.

from keras.datasets import mnist
import numpy as np

(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1)) # adapt this if using `channels_first` image data format
x+test = np.reshape(x_test, (len(x_test), 28, 28, 1)) # adapt this if using `channels_first` image data format

Let's train this model for 50 epochs. For the sake of demonstrating how to visualize the results of a model during training, we will be using the TensorFlow backend and the TensorBoard callback.

First let's open up a terminal and start a TensorBoard server that will read logs stored at /tmp/autoencoder.

Then let's train our model. In the callbacks list we pass an instance of the TensorBoard callback. After every epoch, this callback will write logs to /tmp/autoencoder, which can be read by our TensorBoard server.

from keras.callbacks import TensorBoard

autoencoder.fit(x_train, x_train, epochs=50, batch_size=128, shuffle=True, 
                validation_data=(x_test, x_test), 
                callbacks=[TensorBoard(log_dir='/tmp/autoencoder')])

This allows us to monitor training in the TensorBoard web interface (by navigating to http://0.0.0.0:6006):


In [19]:
class ConvEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 16, 3)
        self.conv1 = nn.Conv2d(16, 8, 3)
        self.conv2 = nn.Conv2d(8,  8, 3)
    def forward(self, x):
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv0(x))                 # conv & actvn (1,16,28,28)
        x = F.max_pool2d(x, 2)                    # pool         (1,16,14,14)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv1(x))                 # conv & actvn (1, 8,14,14)
        x = F.max_pool2d(x, 2)                    # pool         (1, 8, 7, 7)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv2(x))                 # conv & actvn (1, 8, 7, 7)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.max_pool2d(x, 2)                    # pool         (1, 8, 4, 4)
        return x
    
class ConvDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(8, 8, 3)
        self.conv1 = nn.Conv2d(8, 8, 3)
        self.conv2 = nn.Conv2d(8, 16,3)
        self.conv3 = nn.Conv2d(16, 1,3)
        self.upsample = nn.Upsample(scale_factor=2)
    def forward(self, x):
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv0(x))                 # conv & actvn (1, 8, 4, 4)
        x = self.upsample(x)                      # upsample     (1, 8, 8, 8)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv1(x))                 # conv & actvn (1, 8, 8, 8)
        x = self.upsample(x)                      # upsample     (1, 8,16,16)
        x = F.relu(self.conv2(x))                 # conv & actvn (1,16,14,14)
        x = self.upsample(x)                      # upsample     (1,16,28,28)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.sigmoid(self.conv3(x))              # conv & actvn (1, 1,28,28)
        return x

class ConvAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ConvEncoder()
        self.decoder = ConvDecoder()
    def forward(self, x):
        enc_x = self.encoder(x)
        x = self.decoder(enc_x)
        return x, enc_x

I'm going to use the same MNIST data I've been using.


In [306]:
learn = Learner.from_model_data(ConvAutoencoder(), md)

In [307]:
learn.opt_fn = torch.optim.Adadelta
learn.crit   = F.binary_cross_entropy

In [308]:
learn.lr_find()
learn.sched.plot(n_skip_end=0)


epoch      trn_loss   val_loss                                 
    0      0.194118   0.198236  

In [318]:
%time learn.fit(0.1, 50) # 50 1-epoch cycles at lr=0.1


epoch      trn_loss   val_loss                                 
    0      0.150967   0.150391  
    1      0.12907    0.133564                                 
    2      0.122344   0.119396                                 
    3      0.114795   0.113412                                 
    4      0.113377   0.112516                                 
    5      0.109648   0.108521                                 
    6      0.1075     0.106346                                 
    7      0.105783   0.105214                                 
    8      0.104533   0.102116                                 
    9      0.102351   0.101962                                 
    10     0.103103   0.101578                                 
    11     0.101052   0.100249                                  
    12     0.101315   0.099082                                  
    13     0.09977    0.098173                                  
    14     0.100634   0.097466                                  
    15     0.099024   0.098431                                  
    16     0.096336   0.096902                                  
    17     0.097631   0.09617                                   
    18     0.097349   0.095814                                  
    19     0.097988   0.098667                                  
    20     0.096379   0.095681                                  
    21     0.096793   0.095878                                  
    22     0.095395   0.094531                                  
    23     0.097244   0.093837                                  
    24     0.095563   0.093515                                  
    25     0.094987   0.094125                                  
    26     0.094184   0.093576                                  
    27     0.094521   0.094911                                  
    28     0.094767   0.094591                                  
    29     0.093211   0.092355                                  
    30     0.09345    0.09271                                   
    31     0.093887   0.092072                                  
    32     0.095243   0.092447                                  
    33     0.093832   0.091977                                  
    34     0.093866   0.092095                                  
    35     0.092937   0.092397                                  
    36     0.094376   0.092509                                  
    37     0.093592   0.091654                                  
    38     0.092586   0.091105                                  
    39     0.093054   0.091199                                  
    40     0.093674   0.091052                                  
    41     0.092521   0.090781                                  
    42     0.092303   0.091151                                  
    43     0.09286    0.091289                                  
    44     0.0917     0.090959                                  
    45     0.092256   0.090424                                  
    46     0.091967   0.091158                                  
    47     0.09139    0.090865                                  
    48     0.090712   0.092501                                  
    49     0.09053    0.090733                                  
CPU times: user 5h 38min 44s, sys: 50min 14s, total: 6h 28min 59s
Wall time: 2h 27min 1s
Out[318]:
[0.09073325432538987]

In [319]:
learn.save('autoencoder_conv_50ep')

In [320]:
plt.style.use('seaborn')
fig = plt.figure(figsize=(12,6));
ax = plt.subplot(1, 2, 1); ax.plot(learn.sched.losses); ax.set_title('loss')
ax = plt.subplot(1, 2, 2); ax.plot(learn.sched.val_losses); ax.set_title('val_loss');



In [321]:
x,y = next(iter(learn.data.val_dl))
z   = learn.predict()
compare_batch(x,z)



In [322]:
plt.style.use('default')
compare_batch(x,z)


The model converges to a loss of 0.094, significantly better than our previous models (this is in large part due to the higher entropic capacity of the encoded representation, 128 deimensions vs. 32 previously). Let's take a look at the reconstructd digits:

We can also have a look at the 128-dimensional encoded representations. These representations are 8x4x4, so we reshape them to 4x32 in order to eb able to display them as grayscale images.


In [323]:
encodings = learn.model(x)[1]

In [324]:
encodings.shape


Out[324]:
torch.Size([16, 8, 4, 4])

In [336]:
for i in range(bs):
    ax = plt.subplot(1, bs, i+1); ax.imshow(encodings[i].reshape(4,4*8).detach().numpy().T)
    ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)


5. Application to image denoising

Let's put our convolutional autoencoder to work on an image denoising problem. It's simple: we will train the autoencoder to map noisy digits images to clean digits images.

Here's how we will generate synthetic noisy digits: we just apply a gaussian noise matrix and clip the images between 0 and 1.

from keras.datasets import mnist
import numpy as np

(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1)) # adapt this if using `channels_first` data format
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1)) # adapt this if using `channels_first` image data format

noise_factor = 0.5
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)

x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

Here's what the noisy digits look like:

Luckily for me, thanks to Fast.AI and PyTorch, all I need to do is update the dataset's transform and I can get straight back to work. I just need a transform that adds noise. I can use Pytorch's Lambda transform to define my own. I guess I should use pytorch's torch.randn function instead of numpy's.

On second thought, aren't transforms done on the cpu anyway? I'm trying to think about writing generalizable code; in anycase, getting more comfortable with pytorch is a good thing.


In [388]:
def add_noise(tensor, noise_factor=0.3):
    return tensor + noise_factor * torch.randn(tensor.shape)

tfm0 = torchvision.transforms.ToTensor()
tfm1 = torchvision.transforms.Lambda(add_noise)

tfms = torchvision.transforms.Compose([tfm0, tfm1])

learn.data.trn_ds.transform = tfms
learn.data.val_ds.transform = tfms

And let's take a look at our new noisified data (0.5 noise factor):


In [382]:
x,y = next(iter(learn.data.trn_dl))
compare_plot(x,y)


Okay.. way too much noise 😅

Here's noise factor 0.3:


In [389]:
x,y = next(iter(learn.data.trn_dl))
compare_plot(x,y)



In [390]:
compare_plot(*next(iter(learn.data.val_dl)))


If you squint you can still recognize them, but barely. Can our autoencoder learn to recover the original digits? Let's find out.

Compared to the previous convolutional autoencoder, in order to improve the quality of the reconstructed, we'll use a slighly different model with more filters per layer:

input_img = Input(shape=(28, 28, 1) # adapt this if using `channels_first` image data format

x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

# at this point the representation is (7, 7, 32)

x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

My Mac is going to love this.


In [20]:
class ConvEncoderDenoise(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1,  32, 3)
        self.conv1 = nn.Conv2d(32, 32, 3)
    def forward(self, x):
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv0(x))                 # conv & actvn (bs,32,28,28)
        x = F.max_pool2d(x, 2)                    # pool         (bs,32,14,14)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv1(x))                 # conv & actvn (bs, 32,14,14)
        x = F.max_pool2d(x, 2)                    # pool         (bs, 32, 7, 7)
        return x
    
class ConvDecoderDenoise(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(32, 32, 3)
        self.conv1 = nn.Conv2d(32, 32, 3)
        self.conv2 = nn.Conv2d(32, 1,  3)
        self.upsample = nn.Upsample(scale_factor=2)
    def forward(self, x):
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv0(x))                 # conv & actvn (bs,32, 7, 7)
        x = self.upsample(x)                      # upsample     (bs,32,14,14)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.relu(self.conv1(x))                 # conv & actvn (bs,32,14,14)
        x = self.upsample(x)                      # upsample     (bs,32,28,28)
        x = F.pad(x, (1,1,1,1), mode='replicate') # pad
        x = F.sigmoid(self.conv2(x))              # conv & actvn (bs, 1,28,28)
        return x

class ConvAutoencoderDenoise(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ConvEncoderDenoise()
        self.decoder = ConvDecoderDenoise()
    def forward(self, x):
        enc_x = self.encoder(x)
        x = self.decoder(enc_x)
        return x, enc_x

In [393]:
learn = Learner.from_model_data(ConvAutoencoderDenoise(), md)
learn.crit = F.binary_cross_entropy
learn.opt_fn = torch.optim.Adadelta

In [394]:
learn.lr_find()
learn.sched.plot()


epoch      trn_loss   val_loss                                 
    0      0.125226   0.137232  

I did something with pyplot and it made the plots bigger. Ah well. 0.2 Learning rate looks alright. FChollet does this for a 100 epochs.. since I don't want to save-gitpush-gitpull-restart this notebook to run on a AWS GPU instance, I'll leave this to run through the night on my Mac. hopefully that's not an irresponsible way to treat my hardware...


In [396]:
learn.fit(0.2, 100) # lr=0.2 for 100 1-epoch cycles


epoch      trn_loss   val_loss                                  
    0      0.089942   0.089385  
    1      0.084569   0.084538                                  
    2      0.083804   0.083029                                  
    3      0.082711   0.081791                                  
    4      0.082448   0.080888                                  
    5      0.081158   0.080789                                  
    6      0.079876   0.080174                                  
    7      0.081274   0.079656                                  
    8      0.080663   0.07931                                   
    9      0.079917   0.07908                                   
    10     0.080253   0.07902                                   
    11     0.080176   0.07879                                   
    12     0.079051   0.078726                                  
    13     0.078867   0.078639                                  
    14     0.079057   0.07856                                   
    15     0.078319   0.078535                                  
    16     0.079401   0.078097                                  
    17     0.079364   0.078115                                  
    18     0.078791   0.077935                                  
    19     0.078743   0.077841                                  
    20     0.078773   0.07779                                   
    21     0.078692   0.077821                                  
    22     0.078263   0.077632                                  
    23     0.077804   0.077952                                  
    24     0.07744    0.077491                                  
    25     0.078213   0.077395                                  
    26     0.078004   0.077425                                  
    27     0.077568   0.077394                                  
    28     0.077388   0.077376                                  
    29     0.078431   0.077374                                  
    30     0.076934   0.077245                                  
    31     0.077436   0.077186                                  
    32     0.077925   0.077146                                  
    33     0.077818   0.077309                                  
    34     0.076843   0.078006                                  
    35     0.076727   0.077505                                  
    36     0.07683    0.077108                                  
    37     0.077316   0.077173                                  
    38     0.077904   0.077008                                  
    39     0.077381   0.07701                                   
    40     0.07775    0.077174                                  
    41     0.077305   0.076991                                  
    42     0.077165   0.076856                                  
    43     0.077825   0.076888                                  
    44     0.078132   0.076781                                  
    45     0.077043   0.077076                                  
    46     0.077111   0.076788                                  
    47     0.076162   0.076762                                  
    48     0.076943   0.077167                                  
    49     0.078013   0.076825                                  
    50     0.077661   0.076953                                  
    51     0.077129   0.07666                                   
    52     0.076765   0.076992                                  
    53     0.077533   0.076651                                  
    54     0.077461   0.076565                                  
    55     0.078009   0.076567                                  
    56     0.076814   0.076614                                  
    57     0.076488   0.07653                                   
    58     0.076694   0.076569                                  
    59     0.076904   0.076628                                  
    60     0.076574   0.076708                                  
    61     0.076621   0.076784                                  
    62     0.077705   0.076407                                  
    63     0.076608   0.076533                                  
    64     0.076986   0.076622                                  
    65     0.076877   0.076439                                  
    66     0.077435   0.076429                                  
    67     0.077649   0.076446                                  
    68     0.076227   0.076351                                  
    69     0.077574   0.076416                                  
    70     0.076648   0.076478                                  
    71     0.076773   0.076327                                  
    72     0.076983   0.076717                                  
    73     0.076441   0.076358                                  
    74     0.0764     0.076322                                  
    75     0.076812   0.076267                                  
    76     0.076823   0.076357                                  
    77     0.076267   0.076489                                  
    78     0.076314   0.076298                                  
    79     0.076727   0.076156                                  
    80     0.077315   0.076206                                  
    81     0.075834   0.07611                                   
    82     0.0769     0.076177                                  
    83     0.076198   0.076134                                  
    84     0.077615   0.076296                                  
    85     0.076271   0.076204                                  
    86     0.076938   0.076176                                  
    87     0.076326   0.076086                                  
    88     0.076027   0.076183                                  
    89     0.076548   0.076214                                  
    90     0.07622    0.076028                                  
    91     0.076817   0.076264                                  
    92     0.076677   0.076136                                  
    93     0.076821   0.07613                                   
    94     0.076818   0.076081                                  
    95     0.076399   0.07612                                   
    96     0.077422   0.07627                                   
    97     0.076345   0.076072                                  
    98     0.076574   0.076279                                  
    99     0.076597   0.076707                                  
Out[396]:
[0.07670741091966629]

In [397]:
learn.save('autoencoder_conv_denoise_100ep')

In [398]:
x,y = next(iter(learn.data.val_dl))
z = learn.predict()
compare_batch(x, z)



In [399]:
# Higher noise:
def add_noise(tensor, noise_factor=0.5):
    return tensor + noise_factor * torch.randn(tensor.shape)

tfm0 = torchvision.transforms.ToTensor()
tfm1 = torchvision.transforms.Lambda(add_noise)
tfms = torchvision.transforms.Compose([tfm0, tfm1])

learn.data.trn_ds.transform = tfms
learn.data.val_ds.transform = tfms

x,y = next(iter(learn.data.val_dl))
z = learn.predict()
compare_batch(x,z)


It seems to work pretty well. If you scale this process to a bigger convnet, you can start building document denoising or audio denoising models. Kaggle has an interesting dataset to get you started.

6. Sequence–to–sequence autoencoder

If your inputs are sequences, rather than vectors or 2D images, then you may want to use as encoder and decoder a type of model that cna capture temporal structure, such as a LSTM. To build a LSTM-based autoencoder, first use a LSTM encoder to turn your input sequences into a single vector that contains information about the entire sequence, then repeat this vector n times (where n is the number of timesteps in the output sequence), and run a LSTM decoder to turn this constant sequence into the target sequence.

We won't by demonstrating that one on any specific dataset. We will just put a code example for future reference for the reader!

from keras.layers import Input, LSTM, RepeatVector
from keras.models import Model

inputs = Input(shape=(timesteps, input_dim)
encoded = LSTM(latent_dim)(inputs)(inputs)

decoded = RepeatVector(timesteps)(encoded)
decoded = LSTM(input_dim, return_sequences=True)(decoded)

sequence_autoencoder = Model(inputs, decoded)
encoder = Model(inputs, encoded)

I don't know how the 'repeating' of Keras RNNs translates into Pytorch RNNs: whether that's simply a stack of RNNs atop of one another (pytorch: num_layers=n), or what.

FChollet's RNN starts with a shape: (timesteps, input_dim), so if we go by batch that's (timesteps, bs, 28, 28, 1), and it outputs a shape: (latent_dim).

But, the way he has his encoded tensor defined: it's an LSTM outputting shape (latent_dim) applied to inputs, applied to inputs.

So that means he's applying the RNN to the input twice.. but then what's the purpose of timesteps? So there's a difference between a timestep and how many times an RNN runs on a tensor? I thought that was the same thing.

The decoded tensor is the result of applying RepeatVector –– which acc. to Keras' docs just repeats the input n times. So given an input tensor of (n_samples, features) and n, it'll return a tensor of (n_samples, n, features). hmm –– with n = timesteps to the encoded tensor.. meaning decoded is now of shape (batchsize, timesteps, latent_dim) I think..

The decoded tensor then becomes the result of applying an LSTM with output shape input_dim and set to return_sequences=True on itself.

Then the full sequence autoencoder is a Model wrapper applied to the input and that final decoded tensor; the encoder is to input & the encoded tensor.

...Right.

I'm going to come back to this after I have more practice with RNNs in Pytorch.

7. Variational autoencoder (VAE)

Variational autoencoders are a slightly more modern and interesting take on autoencoding.

What is a variational autoencoder, you ask? It's a type of autoencoder with added constraints on the encoded representations being learned. More precisely, it is an autoencoder that learns a latent variable model for its input data. So intead of letting your neural network learn an arbitrary function, you are learning the parameters of a probability distribution modeling your data. If you sample points from this distribution, you can generate new input data samples: a VAE is a "generative model".

How does a variational autoencoder work?

First, an encoder network turns the input samples x into two parameters in a latent space, which we will note z_mean and z_log_sigma. Then, we randomly sample similar points z from the latent normal distribution that is assumed to generate the data, via z = z_mean + exp(z_log_sigma) * epsilon, where the epsilon is a random normal tensor. Finally, a decoder network maps these latent space points back to the original input data.

The parameters of the model are trained via two loss functions: a reconstruction loss forcing the decoded samples to match the initial inputs (just like in our previous autoencoders), and the KL divergence between the learned latent distribution and the prior distribution, acting as a regularization term. You could actually get rid of this latter term entirely, although it does help in learning well-formed latent spaces and reducing overfitting to the training data.

Because a VAE is a more complex example, we have made the code available on Github as a standalone script. Here we will review step by step how the model is created.

First, here's our encoder network, mapping inputs to our latent distribution parameters:

x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_sigma = Dense(latent_dim)(h)

We can use these parameters to sample new similar points from the latent space:

def sampling(args):
    z_mean, z_log_sigma = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0., std=epsilon_std)
    return z_mean + K.exp(z_log_sigma) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_sigma])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma])

Finally, we can map these sampled latent points back to the reconstructed inputs:

decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

What we've done so far allows us to instantiate 3 models:

  • an end-to-end autoencoder mapping inputs to reconstructions
  • an encoder mapping inputs to the latent space
  • a generator that can take points on the latent space and will output the corresponding reconstructed samples.
# end-to-end autoencoder
vae = Model(x, x_decoded_mean)

# encoder, from inputs to latent space
encoder = Model(x, z_mean)

# generator, from latent space to reconstructed inputs
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

We train the model using the end-to-end model, with a custom loss function: the sum of a reconstruction term, and the KL divergence regularization term.

def vae_loss(x, x_decoded_mean):
    xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = -0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
    return xent_loss + kl_loss

vae.compile(optimizer='rmsprop', loss=vae_loss)

So this is a VAE with simple linear layers? Yay my laptop will enjoy that (I forgot to time it, but the Conv-denoising 100 epoch training session took... hours).


In [21]:
class VEncoder(nn.Module):
    """Returns intermediate encodings, mean, and log(stdev) tensors."""
    def __init__(self, input_size, interm_size, latent_size):
        super().__init__()
        self.intermediate = nn.Linear(input_size,  interm_size)
        self.mean_layer   = nn.Linear(interm_size, latent_size)
        self.stdv_layer   = nn.Linear(interm_size, latent_size)
    
    def forward(self, x):
        x     = F.relu(self.intermediate(x))
        μ     = F.relu(self.mean_layer(x)) # Mean vector
        log_σ = F.relu(self.stdv_layer(x)) # Stdv vector
        return x, μ, log_σ

class VSampler(nn.Module):
    """
        Multiplies standard deviation vector by a ~N(0,1) Gaussian distribution.
        Returns mean + new stdev.
        For theory see: https://youtu.be/uaaqyVS9-rM?t=19m42s
    """
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        μ, log_σ = x 
        std_norm = torch.randn(μ.shape) # should I set `requires_grad=True`?
        return μ + torch.exp(log_σ)*std_norm

class VDecoder(nn.Module):
    """Decodes sampled """
    def __init__(self, output_size, interm_size, latent_size):
        super().__init__()
        self.intermediate = nn.Linear(latent_size, interm_size)
        self.out          = nn.Linear(interm_size, output_size)
    
    def forward(self, x):
        x = F.relu(self.intermediate(x))
        x = F.sigmoid(self.out(x))
        return x
        
class VariationalAutoencoder(nn.Module):
    def __init__(self, orign_shape=784, interm_shape=32, latent_shape=16):
        super().__init__()
        self.encoder = VEncoder(orign_shape, interm_shape, latent_shape)
        self.sampler = VSampler()
        self.decoder = VDecoder(orign_shape, interm_shape, latent_shape)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)         # flatten
        enc_x, *μ_log_σ = self.encoder(x) # encode
        x = self.sampler(μ_log_σ)         # sample
        x = self.decoder(x)               # decode
        x = x.reshape(x.size(0),1,28,28)  # 'unflatten' -- could I use x.view(..)?
        return x, μ_log_σ, enc_x

FastAI splits the models output, and calculates the 'raw loss' (via the learner's criterion on the 1st element. If a regularization function is available, the loss is calculated by passing the 1st element of the output, the rest of the output, and the raw loss to that function. (see Stepper.step in fastai/model.py)

So the custom loss function I want has to be compatible with that. I can't just assign my criterion to a function because apparently KL divergence requires the Mean & LogStdev vectors computed by the encoder.

I can have my VAE's forward function hold on to those values and output them, and just deconstruct accordingly in my loss function. BCE will've already've been calculated as 'raw loss', so I can just add that to the KL divergence.


In [22]:
# for kl loss code see: https://wiseodd.github.io/techblog/2017/01/24/vae-pytorch/
# another kl loss (-Σ instead of +Σ: https://github.com/pytorch/examples/blob/master/vae/main.py#L77
# way from keras: https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py#L183

def vae_loss(z, xtra, raw_loss):
    μ_log_σ, _ = xtra
    μ, log_σ   = μ_log_σ
    reconstruction_loss = raw_loss
    kl_divergence_loss  = 0.5 * torch.sum(torch.exp(log_σ) + μ**2 - 1. - log_σ)
    return reconstruction_loss + kl_divergence_loss

In [25]:
# reset ModelData's dataset to not be noisy
tfm0 = torchvision.transforms.ToTensor()
md.trn_ds.transform = md.val_ds.transform = tfm0

In [407]:
# check
x,y = next(iter(md.trn_dl))
compare_plot(x,y)



In [26]:
learn = Learner.from_model_data(VariationalAutoencoder(), md)

learn.opt_fn = torch.optim.Adadelta
learn.crit   = F.binary_cross_entropy
learn.reg_fn = vae_loss

In [27]:
learn.lr_find()
learn.sched.plot()


epoch      trn_loss   val_loss                                 
    0      0.264633   0.264753  

Looks like the Keras example trains for 50 epochs and has a latent-dimension side of 2. Oops. really? only 2? Heh, but the intermediate dimension is 512... I'm doing.. 32. Ehh. Let's see what happens. (note: I changed the interm dim later)


In [28]:
learn.fit(0.15, 50) # 50 1-epoch cycles at lr=0.15


epoch      trn_loss   val_loss                                 
    0      0.280115   0.276759  
    1      0.266848   0.266954                                 
    2      0.267701   0.264851                                 
    3      0.263734   0.264014                                 
    4      0.263267   0.263525                                 
    5      0.264164   0.263301                                 
    6      0.263229   0.263083                                 
    7      0.262408   0.262991                                 
    8      0.260605   0.262933                                 
    9      0.263024   0.26285                                  
    10     0.26458    0.262778                                  
    11     0.262448   0.26275                                   
    12     0.262691   0.262752                                 
    13     0.263262   0.262722                                  
    14     0.263677   0.26272                                   
    15     0.262111   0.262667                                  
    16     0.262238   0.262664                                  
    17     0.261204   0.262675                                  
    18     0.264375   0.26267                                   
    19     0.261294   0.262621                                  
    20     0.261497   0.262637                                  
    21     0.265735   0.26262                                   
    22     0.263462   0.262622                                  
    23     0.261259   0.262609                                  
    24     0.263913   0.2626                                    
    25     0.262395   0.262612                                  
    26     0.263638   0.262608                                  
    27     0.262055   0.262606                                  
    28     0.263085   0.262597                                  
    29     0.263261   0.262585                                  
    30     0.263414   0.262592                                  
    31     0.264423   0.262563                                  
    32     0.262674   0.262542                                  
    33     0.263953   0.262566                                  
    34     0.260652   0.262546                                  
    35     0.264851   0.262554                                  
    36     0.265179   0.262603                                  
    37     0.262004   0.262555                                  
    38     0.263015   0.262538                                  
    39     0.261999   0.262583                                  
    40     0.263641   0.262576                                  
    41     0.263273   0.262581                                  
    42     0.26369    0.262566                                  
    43     0.263175   0.262534                                  
    44     0.263495   0.262544                                  
    45     0.263779   0.262536                                  
    46     0.26333    0.262555                                  
    47     0.262153   0.262576                                  
    48     0.263693   0.262565                                  
    49     0.263763   0.262536                                  

Out[28]:
[0.2625357766866684]

In [29]:
learn.save('variational_autoencoder_50ep')

In [32]:
x,y = next(iter(learn.data.val_dl))
z = learn.predict()
compare_plot(x, z)



In [53]:
len(test_loader.dataset[0])


Out[53]:
2

In [54]:
enc_x.shape


Out[54]:
torch.Size([16, 32])

In [67]:
plt.figure(figsize=(6,6)); plt.style.use('classic');

for x,y in iter(test_loader):
    z, μσ, enc_x = learn.model(x)
    
    plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y); 
plt.colorbar();



In [70]:
plt.style.use('default')
compare_batch(x.detach(), z.detach())