This notebook demonstrates how to build and train a convolutional autoencoder.
Autoencoders consist of two models: an encoder and a decoder.
In this notebook we'll build a convolutional autoencoder to reconstruct MNIST digits. The following animation shows a series of images produced by the decoder over the course of training. The images increasingly resemble handwritten digits as the autoencoder learns to reconstruct the original images.
In [1]:
from __future__ import absolute_import, division, print_function
import glob
import imageio
import os
import PIL
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
from IPython import display
Next, we'll define some of the constants we'll use throughout this notebook. Note that we set EMBED_DIM to 64. This is the dimension of the latent space for our autoencoder.
In [2]:
np.random.seed(1)
tf.random.set_seed(1)
BATCH_SIZE = 128
BUFFER_SIZE = 60000
EPOCHS = 60
LR = 1e-2
EMBED_DIM = 64 # intermediate_dim
In [3]:
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
In [4]:
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
Next, we define our input pipeline using tf.data. The pipeline below reads in train_images as tensor slices, then shuffles and batches the examples for training.
In [5]:
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(BATCH_SIZE*4)
In [6]:
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')
test_images = (test_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
Both our encoder and decoder models will be defined using the Keras Sequential API.
The encoder uses tf.keras.layers.Conv2D layers to map the image into a lower-dimensional latent space. We start with an image of size 28x28x1 and apply convolution layers before mapping to a final Dense layer.
Exercise. Complete the code below to create the CNN-based encoder model. Your model should take inputs of shape 28x28x1 and end with a final Dense layer of size embed_dim.
In [7]:
#TODO 1.
def make_encoder(embed_dim):
    model = tf.keras.Sequential(name="encoder")
    # TODO: Your code goes here.
    assert model.output_shape == (None, embed_dim)
    return model
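For reference, here is one possible way to complete the encoder. The filter counts, kernel sizes, and activations below are illustrative choices rather than the only valid ones; any stack of Conv2D layers that ends in a Dense layer of size embed_dim will satisfy the assertion.

def make_encoder(embed_dim):
    model = tf.keras.Sequential(name="encoder")
    # 28x28x1 -> 14x14x32
    model.add(layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same",
                            activation="relu", input_shape=(28, 28, 1)))
    # 14x14x32 -> 7x7x64
    model.add(layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same",
                            activation="relu"))
    # Flatten and project down to the latent dimension.
    model.add(layers.Flatten())
    model.add(layers.Dense(embed_dim))
    assert model.output_shape == (None, embed_dim)
    return model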
The decoder uses tf.keras.layers.Conv2DTranspose (upsampling) layers to produce an image from the latent space. We start with a Dense layer whose input size is embed_dim, then upsample several times until we reach the desired image size of 28x28x1.
Exercise. Complete the code below to create the decoder model. Start with a Dense layer that takes as input a tensor of size embed_dim. Use tf.keras.layers.Conv2DTranspose over multiple layers to upsample until the final layer has shape 28x28x1 (the shape of our original MNIST digits).
Hint: Experiment with BatchNormalization or different activation functions like LeakyReLU.
In [8]:
#TODO 1.
def make_decoder(embed_dim):
    model = tf.keras.Sequential(name="decoder")
    # TODO: Your code goes here.
    assert model.output_shape == (None, 28, 28, 1)
    return model
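For reference, one possible decoder is sketched below. It mirrors the encoder sketch above: a Dense layer reshaped to 7x7x64 followed by two Conv2DTranspose upsampling steps. The BatchNormalization and LeakyReLU layers and the final tanh activation (which matches the [-1, 1] pixel scaling used earlier) are illustrative choices.

def make_decoder(embed_dim):
    model = tf.keras.Sequential(name="decoder")
    # Project the latent vector up to a 7x7x64 feature map.
    model.add(layers.Dense(7 * 7 * 64, input_shape=(embed_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 64)))
    # 7x7x64 -> 14x14x32
    model.add(layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding="same"))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # 14x14x32 -> 28x28x1, tanh to match the [-1, 1] pixel range.
    model.add(layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding="same",
                                     activation="tanh"))
    assert model.output_shape == (None, 28, 28, 1)
    return model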
Finally, we stitch the encoder and decoder models together to create our autoencoder.
In [9]:
ae_model = tf.keras.models.Sequential([make_encoder(EMBED_DIM), make_decoder(EMBED_DIM)])
Using .summary() we can get a high-level summary of the full autoencoder model, as well as of the individual encoder and decoder. Note how the shapes of the tensors mirror each other as data is passed through the encoder and then the decoder.
In [10]:
ae_model.summary()
In [11]:
make_encoder(EMBED_DIM).summary()
In [12]:
make_decoder(EMBED_DIM).summary()
Next, we define the loss for our autoencoder model. The loss we will use is the reconstruction error. This loss is similar to the MSE loss we commonly use for regression. Here we apply the error pixel-wise to compare the original MNIST image and the image reconstructed by the decoder.
In [13]:
#TODO 2.
def loss(model, original):
    reconstruction_error = # TODO: Your code goes here.
    return reconstruction_error
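For reference, a minimal sketch of the pixel-wise reconstruction error, assuming mean squared error between the original batch and the autoencoder's output:

def loss(model, original):
    # Pass the original images through the autoencoder.
    reconstruction = model(original)
    # Pixel-wise mean squared error between reconstruction and original.
    reconstruction_error = tf.reduce_mean(tf.square(reconstruction - original))
    return reconstruction_error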
In [14]:
optimizer = tf.keras.optimizers.SGD(learning_rate=LR)
In [15]:
checkpoint_dir = "./ae_training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=ae_model)
Exercise. Complete the code below to define the training loop for our autoencoder. Notice the use of tf.function below. This annotation causes the function train_step to be "compiled". The train_step function takes as input a batch of images and passes them through the ae_model. The gradient is then computed on the loss against the ae_model output and the original image. In the code below, you should (1) compute ae_gradients, the gradient of the autoencoder loss with respect to the variables of the ae_model; (2) create gradient_variables by assigning each gradient computed above to its respective training variable; and (3) apply the gradients to those variables using the optimizer.
In [16]:
#TODO 3.
@tf.function
def train_step(images):
    with tf.GradientTape() as tape:
        ae_gradients = # TODO: Your code goes here.
        gradient_variables = # TODO: Your code goes here.
        # TODO: Your code goes here.
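For reference, one possible training step is sketched below. It assumes the loss function defined above returns the reconstruction error; the gradients are taken with respect to ae_model.trainable_variables, paired with those variables, and applied by the optimizer.

@tf.function
def train_step(images):
    with tf.GradientTape() as tape:
        # Reconstruction loss for this batch.
        ae_loss = loss(ae_model, images)
    # Gradient of the reconstruction loss w.r.t. the autoencoder's weights.
    ae_gradients = tape.gradient(ae_loss, ae_model.trainable_variables)
    # Pair each gradient with its corresponding trainable variable.
    gradient_variables = zip(ae_gradients, ae_model.trainable_variables)
    # Apply the gradients with the optimizer.
    optimizer.apply_gradients(gradient_variables)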
We use the train_step function above to define training of our autoencoder. Note that the train function takes as arguments the tf.data dataset and the number of epochs for training.
In [17]:
def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(ae_model,
                                 epoch + 1,
                                 test_images[:16, :, :, :])

        # Save the model every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(
            epoch + 1, time.time() - start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(ae_model,
                             epochs,
                             test_images[:16, :, :, :])
Generate and save images. We'll use a small helper function to generate images and save them.
In [18]:
def generate_and_save_images(model, epoch, test_input):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        pixels = predictions[i, :, :] * 127.5 + 127.5
        pixels = np.array(pixels, dtype='float')
        pixels = pixels.reshape((28, 28))
        plt.imshow(pixels, cmap='gray')
        plt.axis('off')

    # Save into ./ae_images so the display and GIF cells below can find the files.
    os.makedirs('./ae_images', exist_ok=True)
    plt.savefig('./ae_images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
Let's see how our model performs before any training. We'll take as input the first 16 digits of the MNIST test set. Right now the reconstructions just look like random noise.
In [19]:
generate_and_save_images(ae_model, 4, test_images[:16, :, :, :])
Call the train() method defined above to train the autoencoder model.
We'll print the resulting images as training progresses. At the beginning of training, the decoded images look like random noise. As training progresses, the reconstructions improve, and after about 50 epochs they resemble MNIST digits. Training may take one to two minutes per epoch.
In [ ]:
#TODO 4.
# TODO: Your code goes here.
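A minimal way to complete this step, using the tf.data dataset and the epoch count defined earlier:

train(train_dataset, EPOCHS)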
In [ ]:
# Display a single image using the epoch number
def display_image(epoch_no):
    return PIL.Image.open('./ae_images/image_at_epoch_{:04d}.png'.format(epoch_no))
In [ ]:
display_image(EPOCHS)
In [ ]:
anim_file = 'autoencoder.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = glob.glob('./ae_images/image*.png')
    filenames = sorted(filenames)
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2 * (i ** 0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    # Append the last frame once more so the GIF lingers on the final image.
    image = imageio.imread(filename)
    writer.append_data(image)
import IPython
if IPython.version_info > (6,2,0,''):
    display.Image(filename=anim_file)
Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License