Training an image classifier

We will do the following steps in order:

  1. Load and normalize the MNIST training and test datasets using torchvision
  2. Define a neural network
  3. Define a loss function
  4. Train the network on the training data
  5. Test the network on the test data

1. Loading and normalizing MNIST

Using torchvision, it’s extremely easy to load MNIST.

import torch
import torchvision
import torchvision.transforms as transforms

The outputs of torchvision datasets are PILImage images of range [0, 1]; MNIST images are single-channel (grayscale). We transform them to Tensors of normalized range [-1, 1]: with mean 0.5 and standard deviation 0.5, each pixel x is mapped to (x - 0.5) / 0.5, so 0 goes to -1 and 1 goes to 1. MNIST contains 60,000 training samples and 10,000 test samples.


In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])  # one mean/std: MNIST is single-channel

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
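
As a quick, optional sanity check, we can pull a single batch from the loader and confirm the shape and the normalized value range:

# one batch: 4 grayscale 28x28 images, values normalized to about [-1, 1]
images, labels = next(iter(trainloader))
print(images.shape)                              # torch.Size([4, 1, 28, 28])
print(images.min().item(), images.max().item())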
2. Define a Neural Network

Three fully connected layers: the input size is the height × width of the image (28 × 28 = 784), and the output size is the number of classes (10 in the case of MNIST).

Use the base class nn.Module.

nn.Module mainly takes care of storing the parameters of the neural network.


In [4]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        
        # drop the channel dimension and flatten each image to a 784-vector
        x = x[:, 0, ...].view(-1, 28 * 28)
        
        # feed layer 1
        out_layer1 = self.fc1(x)
        out_layer1 = F.relu(out_layer1)
        
        # feed layer 2
        out_layer2 = self.fc2(out_layer1)
        out_layer2 = F.relu(out_layer2)
        
        # feed layer 3
        out_layer3 = self.fc3(out_layer2)
        
        return out_layer3


net = Net()
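
Because fc1, fc2, and fc3 are assigned as attributes, nn.Module registers their weights and biases automatically; this is what the optimizer will consume via net.parameters() below. A minimal way to inspect them:

# list each registered parameter tensor and its shape
for name, p in net.named_parameters():
    print(name, tuple(p.shape))

# total trainable parameters:
# 784*120 + 120 + 120*84 + 84 + 84*10 + 10 = 105,214
print(sum(p.numel() for p in net.parameters()))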
3. Define a Loss function and optimizer

Let's use a Classification Cross-Entropy loss and SGD with momentum.

In [5]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
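
Note that CrossEntropyLoss combines LogSoftmax and NLLLoss, so it expects the network's raw logits together with integer class labels (not one-hot vectors). A small illustration with made-up values:

# hypothetical batch of 4 samples with 10 classes each
dummy_logits = torch.randn(4, 10)
dummy_targets = torch.tensor([3, 7, 0, 1])     # ground-truth digit indices
print(criterion(dummy_logits, dummy_targets))  # a scalar loss tensor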
4. Train the network

This is when things start to get interesting. We simply loop over our data iterator, feed the inputs to the network, and optimize.


In [12]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 99 == 0:    # print roughly every 100 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')


[1,     1] loss: 0.000
[1,   100] loss: 0.109
[1,   199] loss: 0.132
[1,   298] loss: 0.149
...
[1, 14851] loss: 0.144
[1, 14950] loss: 0.189
[2,     1] loss: 0.000
[2,   100] loss: 0.091
[2,   199] loss: 0.127
...
[2, 14851] loss: 0.130
[2, 14950] loss: 0.107
Finished Training
5. Test the network on the test data

We have trained the network for two passes over the training dataset, but we need to check whether the network has learnt anything at all.

We will check this by taking the class label that the neural network predicts and checking it against the ground truth. If the prediction is correct, we add the sample to the list of correct predictions.

First, let's load the test set.


In [7]:
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

Now let's measure performance on the whole test dataset.


In [11]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        
        # the index of the highest output is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))


Accuracy of the network on the 10000 test images: 96 %
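
For a finer-grained view, the same evaluation loop can be extended to report accuracy per digit; a sketch reusing net and testloader from above (outputs omitted, since this variant was not run here):

# count correct predictions separately for each of the 10 digits
class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        for label, pred in zip(labels, predicted):
            class_total[label.item()] += 1
            class_correct[label.item()] += int(pred.item() == label.item())

for digit in range(10):
    print('Accuracy of digit %d: %2d %%'
          % (digit, 100 * class_correct[digit] / class_total[digit]))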

Finally, let's plot some images together with the network's predictions:


In [10]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
predictions = net(images)

_, predicted = torch.max(predictions, 1)

# show images
imshow(torchvision.utils.make_grid(images))
# print the network's predictions for the four images
print(' '.join('%5s' % predicted[j].item() for j in range(4)))


    4     0     1     3
