03 - Dropout regularization from scratch


In [1]:
from __future__ import print_function
import mxnet as mx
import numpy as np
from mxnet import gluon
from tqdm import tqdm
mx.random.seed(1)

Context


In [2]:
ctx = mx.cpu()

MNIST Dataset


In [3]:
mnist = mx.test_utils.get_mnist()
batch_size = 64

In [4]:
def transform(data, label):
    return data.astype(np.float32) / 255, label.astype(np.float32)

In [5]:
train_data = gluon.data.DataLoader(dataset=gluon.data.vision.MNIST(train=True, transform=transform),
                                   batch_size=batch_size,
                                   shuffle=True)
test_data = gluon.data.DataLoader(dataset=gluon.data.vision.MNIST(train=False, transform=transform),
                                  batch_size=batch_size, 
                                  shuffle=False)
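
As a quick sanity check (a sketch, not part of the original run), we can peek at a single batch to confirm its shape; each image comes out of gluon.data.vision.MNIST as 28x28x1 and gets flattened to 784 later, inside the training loop.

for data, label in train_data:
    print(data.shape, label.shape)  # expected: (64, 28, 28, 1) (64,)
    break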

Defining network variables


In [6]:
W1 = mx.nd.random_normal(shape=(784, 256), ctx=ctx) * .01
b1 = mx.nd.random_normal(shape=256, ctx=ctx) * .01

W2 = mx.nd.random_normal(shape=(256, 128), ctx=ctx) * .01
b2 = mx.nd.random_normal(shape=128, ctx=ctx) * .01

W3 = mx.nd.random_normal(shape=(128, 10), ctx=ctx) * .01
b3 = mx.nd.random_normal(shape=10, ctx=ctx) * .01

params = [W1, b1, W2, b2, W3, b3]

In [7]:
for param in params:
    param.attach_grad()

ReLU


In [8]:
def relu(X):
    return mx.nd.maximum(X, 0)

Dropout


In [9]:
def dropout(X, drop_probability):
    keep_probability = 1 - drop_probability
    mask = mx.nd.random_uniform(low=0,
                                high=1.0,
                                shape=X.shape,
                                ctx=X.context) < keep_probability
    if keep_probability > 0.0:
        scale = (1 / keep_probability)
    else:
        scale = 0.0
    return mask * X * scale

Dropout: an example


In [10]:
A = mx.nd.arange(20).reshape((5,4))
dropout(A, 0.0)


Out[10]:
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]
 [16. 17. 18. 19.]]
<NDArray 5x4 @cpu(0)>

In [11]:
dropout(A, 0.5)


Out[11]:
[[ 0.  2.  4.  6.]
 [ 8.  0. 12. 14.]
 [16. 18. 20. 22.]
 [24.  0. 28. 30.]
 [32. 34.  0. 38.]]
<NDArray 5x4 @cpu(0)>

In [12]:
dropout(A, 1.0)


Out[12]:
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
<NDArray 5x4 @cpu(0)>
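
Because surviving activations are rescaled by 1 / keep_probability (inverted dropout), each unit keeps the same expected value. A minimal sketch of that check, averaging many dropout samples of A:

# Average many dropout samples; the result should be close to A itself,
# since inverted dropout preserves the expected activation.
n_trials = 1000
total = mx.nd.zeros_like(A)
for _ in range(n_trials):
    total = total + dropout(A, 0.5)
print(total / n_trials)  # approximately equal to A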

Softmax


In [13]:
def softmax(y_linear):
    exp = mx.nd.exp(y_linear - mx.nd.max(y_linear))
    partition = mx.nd.nansum(data=exp,
                             axis=0,
                             exclude=True).reshape((-1,1))
    return exp / partition
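
A small sanity check (a sketch, not in the original notebook): each row of the softmax output should be a valid probability distribution, i.e. sum to 1.

scores = mx.nd.array([[1., 2., 3.],
                      [1., 1., 1.]])
probs = softmax(scores)
print(probs)
print(mx.nd.sum(probs, axis=1))  # both row sums should be 1.0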

Softmax Cross-entropy Loss Function


In [14]:
def softmax_cross_entropy(yhat_linear, y):
    return - mx.nd.nansum(y * mx.nd.log_softmax(yhat_linear),
                          axis=0,
                          exclude=True)
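
Since the labels are one-hot vectors, this loss reduces to the negative log of the softmax probability assigned to the true class. A quick check on made-up logits (a sketch, using the softmax defined above):

logits = mx.nd.array([[2., 0., -1.]])
one_hot = mx.nd.one_hot(mx.nd.array([0]), 3)   # true class is index 0
print(softmax_cross_entropy(logits, one_hot))  # loss for this example
print(-mx.nd.log(softmax(logits)))             # its first entry should match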

Network


In [15]:
def net(X, drop_prob=0.0):
    #######################
    #  Compute the first hidden layer
    #######################
    h1_linear = mx.nd.dot(X, W1) + b1
    h1 = relu(h1_linear)
    h1 = dropout(h1, drop_prob)

    #######################
    #  Compute the second hidden layer
    #######################
    h2_linear = mx.nd.dot(h1, W2) + b2
    h2 = relu(h2_linear)
    h2 = dropout(h2, drop_prob)

    #######################
    #  Compute the output layer.
    #  We will omit the softmax function here
    #  because it will be applied
    #  in the softmax_cross_entropy loss
    #######################
    yhat_linear = mx.nd.dot(h2, W3) + b3
    return yhat_linear
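
A quick shape check on a dummy batch (a sketch; the names here are illustrative only): activations flow (batch, 784) -> (batch, 256) -> (batch, 128) -> (batch, 10).

dummy = mx.nd.random_uniform(shape=(3, 784), ctx=ctx)
print(net(dummy).shape)                 # expected: (3, 10), dropout off by default
print(net(dummy, drop_prob=.5).shape)   # same shape, with dropout applied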

Stochastic Gradient Descent


In [16]:
def SGD(params, lr):
    for param in params:
        param[:] = param - lr * param.grad
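
SGD updates each parameter in place via the slice assignment param[:] = ..., so the gradient buffers attached by attach_grad stay valid. A minimal sketch of one manual step on a toy parameter (w and y are hypothetical, not part of the model):

w = mx.nd.array([1., 2.])
w.attach_grad()
with mx.autograd.record():
    y = (w * w).sum()          # gradient is 2 * w
y.backward()
SGD([w], lr=0.1)
print(w)                       # expected: [0.8, 1.6]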

Evaluation


In [17]:
def evaluate_accuracy(data_iterator, net):
    numerator = 0.
    denominator = 0.
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx).reshape((-1, 784))
        label = label.as_in_context(ctx)
        output = net(data)  # default drop_prob=0.0: dropout is disabled at evaluation time
        predictions = mx.nd.argmax(output,
                                   axis=1)
        numerator += mx.nd.sum(predictions == label)
        denominator += data.shape[0]
    return (numerator / denominator).asscalar()

Training


In [18]:
epochs = 10
moving_loss = 0.
learning_rate = .001

In [19]:
for e in tqdm(range(epochs)):
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx).reshape((-1,784))
        label = label.as_in_context(ctx)
        label_one_hot = mx.nd.one_hot(label, 10)
        with mx.autograd.record():
            ################################
            #   Drop out 50% of hidden activations on the forward pass
            ################################
            output = net(data, drop_prob=.5)
            loss = softmax_cross_entropy(output, label_one_hot)
        loss.backward()
        SGD(params, learning_rate)

        ##########################
        #  Keep a moving average of the losses
        ##########################
        if i == 0:
            moving_loss = mx.nd.mean(loss).asscalar()
        else:
            moving_loss = .99 * moving_loss + .01 * mx.nd.mean(loss).asscalar()

    test_accuracy = evaluate_accuracy(test_data, net)
    train_accuracy = evaluate_accuracy(train_data, net)
    print("Epoch %s. Loss: %s, Train_acc %s, Test_acc %s" % (e, moving_loss, train_accuracy, test_accuracy))


Epoch 0. Loss: 0.732930123066723, Train_acc 0.85461664, Test_acc 0.8601
Epoch 1. Loss: 0.37699336618958423, Train_acc 0.92035, Test_acc 0.9219
Epoch 2. Loss: 0.28828520928401863, Train_acc 0.9418667, Test_acc 0.9396
Epoch 3. Loss: 0.23723796996234844, Train_acc 0.95595, Test_acc 0.9539
Epoch 4. Loss: 0.20619299122066936, Train_acc 0.96168333, Test_acc 0.9567
Epoch 5. Loss: 0.1830389900122309, Train_acc 0.9690667, Test_acc 0.9655
Epoch 6. Loss: 0.17125523925013506, Train_acc 0.97241664, Test_acc 0.9659
Epoch 7. Loss: 0.15506527236897544, Train_acc 0.97671664, Test_acc 0.9703
Epoch 8. Loss: 0.1459825592868318, Train_acc 0.97795, Test_acc 0.9712
Epoch 9. Loss: 0.12992356050813567, Train_acc 0.97943336, Test_acc 0.9727
100%|██████████| 10/10 [05:54<00:00, 35.48s/it]

In [20]:
train_accuracy


Out[20]:
0.97943336

In [21]:
test_accuracy


Out[21]:
0.9727