Problem 3: Training and Fine-tuning on Fashion MNIST and MNIST

Training neural networks with a huge number of parameters on a small dataset greatly hurts the networks' generalization ability, often resulting in overfitting. Therefore, in practice one more often fine-tunes an existing network that was trained on a larger dataset by continuing training on the smaller dataset. To get familiar with the fine-tuning procedure, in this problem you need to train a model from scratch on the Fashion MNIST dataset and then fine-tune it on the MNIST dataset. Note that we train models on these two toy datasets only because of limited computational resources; in most cases, models are trained on ImageNet and fine-tuned on smaller datasets.

  • Learning Objective: In Problem 2, you implemented a convolutional neural network to perform a classification task in TensorFlow. In this part of the assignment, we will show you how to use TensorFlow to fine-tune a trained network on a different task.
  • Provided Codes: We provide the dataset downloading and preprocessing code, as well as the conv2d() and fc() functions used to build the model for the fine-tuning task.
  • TODOs: Train a model from scratch on the Fashion MNIST dataset and then fine-tune it on the MNIST dataset (a high-level outline follows below). Both the training loss and the training accuracy need to be shown.
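
At a high level, both stages share a single graph and a single set of weights: we build the model once, train it on Fashion MNIST, and then continue training the same (not re-initialized) weights on MNIST. A rough outline with hypothetical helper names (the actual implementation is the train() function below):

# Outline only; build_model() and run_epochs() are hypothetical helpers.
model = build_model()                          # one graph, one set of weights
run_epochs(model, fashion_mnist, num_epoch=5)  # stage 1: train from scratch
run_epochs(model, mnist, num_epoch=5)          # stage 2: fine-tune by continuing training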

In [1]:
import numpy as np
import os.path as osp
import os
import subprocess

def download_data(download_root='data/', dataset='mnist'):
    if dataset == 'mnist':
        data_url = 'http://yann.lecun.com/exdb/mnist/'
    elif dataset == 'fashion_mnist':
        data_url = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    else:
        raise ValueError('Please specify mnist or fashion_mnist.')

    data_dir = osp.join(download_root, dataset)
    if osp.exists(data_dir):
        print('The dataset has already been downloaded.')
        return
    else:
        os.makedirs(data_dir)  # also creates download_root if it does not exist yet

    keys = ['train-images-idx3-ubyte.gz', 't10k-images-idx3-ubyte.gz',
            'train-labels-idx1-ubyte.gz', 't10k-labels-idx1-ubyte.gz']

    for k in keys:
        url = data_url + k
        target_path = osp.join(data_dir, k)
        cmd = ['curl', url, '-o', target_path]
        print('Downloading', k)
        subprocess.call(cmd)
        cmd = ['gzip', '-d', target_path]
        print('Unzip', k)
        subprocess.call(cmd)


def load_data(data_dir):
    num_train = 60000
    num_test = 10000

    def load_file(filename, offset, shape):
        # Skip the idx header (16 bytes for images, 8 bytes for labels),
        # then reshape the raw bytes and cast to float32.
        with open(osp.join(data_dir, filename), 'rb') as fd:
            loaded = np.fromfile(file=fd, dtype=np.uint8)
        return loaded[offset:].reshape(shape).astype(np.float32)

    train_image = load_file('train-images-idx3-ubyte', 16, (num_train, 28, 28, 1))
    train_label = load_file('train-labels-idx1-ubyte', 8, num_train)
    test_image = load_file('t10k-images-idx3-ubyte', 16, (num_test, 28, 28, 1))
    test_label = load_file('t10k-labels-idx1-ubyte', 8, num_test)
    return train_image, train_label, test_image, test_label
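
As a quick sanity check on these loaders (assuming the files have already been downloaded and unzipped into data/mnist), the shapes and value range should look as follows; note that the pixels are raw bytes in [0, 255], not normalized:

# Sanity check for load_data(); assumes data/mnist already exists.
X_tr, Y_tr, X_te, Y_te = load_data('data/mnist')
print(X_tr.shape, Y_tr.shape)  # (60000, 28, 28, 1) (60000,)
print(X_te.shape, Y_te.shape)  # (10000, 28, 28, 1) (10000,)
print(X_tr.min(), X_tr.max())  # 0.0 255.0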

In [2]:
# Download MNIST and Fashion MNIST
download_data(dataset='mnist')
download_data(dataset='fashion_mnist')


Downloading train-images-idx3-ubyte.gz
Unzip train-images-idx3-ubyte.gz
Downloading t10k-images-idx3-ubyte.gz
Unzip t10k-images-idx3-ubyte.gz
Downloading train-labels-idx1-ubyte.gz
Unzip train-labels-idx1-ubyte.gz
Downloading t10k-labels-idx1-ubyte.gz
Unzip t10k-labels-idx1-ubyte.gz
Downloading train-images-idx3-ubyte.gz
Unzip train-images-idx3-ubyte.gz
Downloading t10k-images-idx3-ubyte.gz
Unzip t10k-images-idx3-ubyte.gz
Downloading train-labels-idx1-ubyte.gz
Unzip train-labels-idx1-ubyte.gz
Downloading t10k-labels-idx1-ubyte.gz
Unzip t10k-labels-idx1-ubyte.gz

In [3]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import matplotlib.pyplot as plt
%matplotlib inline

def conv2d(input, output_shape, k=4, s=2, name='conv2d'):
    # k x k convolution with stride s; slim defaults to 'SAME' padding
    # and a ReLU activation.
    with tf.variable_scope(name):
        return slim.conv2d(input, output_shape, [k, k], stride=s)


def fc(input, output_shape, act_fn=tf.nn.relu, name='fc'):
    # Fully connected layer; pass act_fn=None for a linear (logit) output.
    with tf.variable_scope(name):
        return slim.fully_connected(input, output_shape, activation_fn=act_fn)
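
For reference, slim.conv2d defaults to 'SAME' padding, so each stride-2 convolution shrinks the spatial size to ceil(input / stride): 28 -> 14 -> 7 -> 4. A minimal check of that arithmetic, which is why the flatten step in train() below produces 4 * 4 * 256 = 4096 features:

size = 28
for _ in range(3):          # three stride-2 convolutions with 'SAME' padding
    size = (size + 1) // 2  # 'SAME' padding: output = ceil(input / stride)
print(size)                 # 4
print(size * size * 256)    # 4096 features entering fc1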


def train(batch_size=100, num_epoch=5, learning_rate=1e-4,
          num_train=60000, num_test=10000):
    sess = tf.InteractiveSession()
    
    # Build the model
    X = tf.placeholder(tf.float32, [None, 28, 28, 1])
    Y = tf.placeholder(tf.int64, [None])
    labels = tf.one_hot(Y, 10)
    _ = conv2d(X, 32, name='conv1')    # 28x28 -> 14x14
    _ = conv2d(_, 64, name='conv2')    # 14x14 -> 7x7
    _ = conv2d(_, 256, name='conv3')   # 7x7 -> 4x4
    _ = tf.reshape(_, [-1, np.prod(_.get_shape().as_list()[1:])])  # flatten to (batch, 4096)
    _ = fc(_, 256, name='fc1')
    logits = fc(_, 10, act_fn=None, name='fc2')

    loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    loss_op = tf.reduce_mean(loss)

    global_step = tf.Variable(0, trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss_op, global_step=global_step)

    predict = tf.argmax(logits, 1)
    correct = tf.equal(predict, Y)
    accuracy_op = tf.reduce_mean(tf.cast(correct, tf.float32))

    sess.run(tf.global_variables_initializer())

    total_loss = []
    total_accuracy = []

    print('\033[93mTrain Fashion MNIST\033[0m')
    X_train, Y_train, X_test, Y_test = load_data('data/fashion_mnist')
    #############################################################################
    # TODO: Train the model on Fashion MNIST from scratch                       #
    # and then fine-tune it on MNIST                                            #
    # Collect the training loss and the training accuracy                       #
    # fetched from each iteration                                               #
    # After the two stages of training, the length of                          #
    # total_loss and total_accuracy should be                                  #
    # 2 * num_epoch * num_train / batch_size = 2 * 5 * 60000 / 100 = 6000      #
    #############################################################################
    # Train the model on Fashion MNIST
    for epoch in range(num_epoch):
        for i in range(num_train // batch_size):
            X_ = X_train[i * batch_size:(i + 1) * batch_size]
            Y_ = Y_train[i * batch_size:(i + 1) * batch_size]
            
            feed_dict = {X : X_, Y : Y_}
            fetches = [train_op, loss_op, accuracy_op]
            _, loss, accuracy = sess.run(fetches, feed_dict=feed_dict)
            total_loss.append(loss)
            total_accuracy.append(accuracy)
            
        print('[Epoch {}] loss: {}, accuracy: {}'.format(epoch, loss, accuracy))


    # Train the model on MNIST
    print('\033[93mTrain MNIST\033[0m')
    X_train, Y_train, X_test, Y_test = load_data('data/mnist')
    for epoch in range(num_epoch):
        for i in range(num_train // batch_size):
            X_ = X_train[i * batch_size:(i + 1) * batch_size]
            Y_ = Y_train[i * batch_size:(i + 1) * batch_size]
            
            feed_dict = {X : X_, Y : Y_}
            fetches = [train_op, loss_op, accuracy_op]
            _, loss, accuracy = sess.run(fetches, feed_dict=feed_dict)
            total_loss.append(loss)
            total_accuracy.append(accuracy)
        print('[Epoch {}] loss: {}, accuracy: {}'.format(epoch, loss, accuracy))

    #############################################################################
    #                             END OF YOUR CODE                              #
    #############################################################################
    return total_loss, total_accuracy
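
Note that both stages above update all weights; continuing training of the entire network is itself a valid form of fine-tuning. A common variant (not required by this assignment) is to freeze the convolutional layers during the second stage and update only the fully connected head. A minimal sketch, assuming the 'fc1'/'fc2' variable scopes used in train() above:

# Hypothetical variant: fine-tune only the fully connected head.
fc_vars = [v for v in tf.trainable_variables()
           if v.name.startswith('fc1/') or v.name.startswith('fc2/')]
finetune_op = optimizer.minimize(loss_op, var_list=fc_vars,
                                 global_step=global_step)
# Run finetune_op instead of train_op during the MNIST stage.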

In [4]:
loss, accuracy = train()


Train Fashion MNIST
[Epoch 0] loss: 0.312357932329, accuracy: 0.879999995232
[Epoch 1] loss: 0.236116871238, accuracy: 0.930000007153
[Epoch 2] loss: 0.20353718102, accuracy: 0.939999997616
[Epoch 3] loss: 0.163721293211, accuracy: 0.97000002861
[Epoch 4] loss: 0.163444876671, accuracy: 0.939999997616
Train MNIST
[Epoch 0] loss: 0.260136812925, accuracy: 0.97000002861
[Epoch 1] loss: 0.240254819393, accuracy: 0.980000019073
[Epoch 2] loss: 0.220106124878, accuracy: 0.990000009537
[Epoch 3] loss: 0.182692378759, accuracy: 0.990000009537
[Epoch 4] loss: 0.144816577435, accuracy: 0.990000009537

In [5]:
# Plot the training loss and the training accuracy
plt.plot(loss)
plt.title('training loss')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.show()

plt.plot(accuracy)
plt.title('training accuracy')
plt.xlabel('iteration')
plt.ylabel('accuracy')
plt.show()
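
Since the two stages are concatenated, expect a visible bump in the loss curve around iteration 3000, where training switches from Fashion MNIST to MNIST (compare the epoch-4 Fashion MNIST loss of ~0.16 with the epoch-0 MNIST loss of ~0.26 above), followed by a quick recovery as the pretrained features adapt to the new digits.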