In this lab, you will use a single layer Softmax to classify handwritten digits from the MNIST database.
In [ ]:
!conda install -y torchvision
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import matplotlib.pylab as plt
import numpy as np
Use the following function to visualize data:
In [ ]:
def show_data(data_sample):
    plt.imshow(data_sample[0].numpy().reshape(28, 28), cmap='gray')
    plt.title('y = ' + str(data_sample[1].item()))
Load the training dataset by setting the parameter train to True, and convert each image to a tensor by passing a transforms.ToTensor() object to the transform argument.
In [ ]:
train_dataset=dsets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_dataset
Load the validation dataset by setting the parameter train to False, and convert the images to tensors in the same way.
In [ ]:
validation_dataset=dsets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
validation_dataset
Note that the data type of the label is long:
In [ ]:
train_dataset[0][1].type()
Each element of the rectangular 28x28 tensor is a number representing the intensity of one pixel.
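As a quick illustrative check, you can print a small patch of raw pixel values (after ToTensor() the intensities are floats between 0 and 1):
In [ ]:
image, label = train_dataset[0]
print(image.shape)            # torch.Size([1, 28, 28])
print(image[0, 14, 10:18])    # a few pixel intensities from a row near the center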
Print out the label of the sample at index 3:
In [ ]:
train_dataset[3][1]
Plot the sample at index 3:
In [ ]:
show_data(train_dataset[3])
You can see it's a 1. Now plot the sample at index 2:
In [ ]:
show_data(train_dataset[2])
The Softmax classifier requires vector (one-dimensional) inputs, but if you check the shape of a sample you will see it is a 1x28x28 tensor:
In [ ]:
train_dataset[0][0].shape
Flatten each 28x28 image into a one-dimensional tensor; its size is then 28 x 28 = 784.
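As a quick sketch of what the flattening looks like, you can apply view to a single image (the variable names here are just for illustration):
In [ ]:
image = train_dataset[0][0]
image_flat = image.view(-1, 28*28)    # flatten the 1x28x28 image into a row of 784 values
print(image_flat.shape)               # torch.Size([1, 784])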
Set the input size and output size.
In [ ]:
input_dim=28*28
output_dim=10
input_dim
In [ ]:
Build the classifier: define a module with a single linear layer that maps input_dim inputs to output_dim class scores, and create a model object. Double-click here for the solution.
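If you get stuck, the following is a minimal sketch of what the solution might look like: a custom SoftMax module wrapping a single nn.Linear layer (the class name and structure here are illustrative; your own solution may differ):
In [ ]:
class SoftMax(nn.Module):
    # a single linear layer mapping 784 pixel values to 10 class scores
    def __init__(self, input_size, output_size):
        super(SoftMax, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        # return raw scores (logits); the softmax is applied inside the loss function
        return self.linear(x)

model = SoftMax(input_dim, output_dim)
model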
View the size of the model parameters:
In [ ]:
print('W:',list(model.parameters())[0].size())
print('b:',list(model.parameters())[1].size())
Convert the model parameters for each class into a rectangular 28x28 grid and plot them:
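A minimal sketch of one way to do this, reshaping each row of the weight matrix back into a 28x28 grid and displaying it as an image (the helper name PlotParameters and the colormap are illustrative choices):
In [ ]:
def PlotParameters(model):
    # each of the 10 rows of W holds 784 weights; reshape each row into a 28x28 grid
    W = list(model.parameters())[0].data
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    for i, ax in enumerate(axes.flat):
        ax.imshow(W[i].view(28, 28).numpy(), cmap='seismic')
        ax.set_title('class: ' + str(i))
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

PlotParameters(model)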
Loss function:
In [ ]:
criterion=nn.CrossEntropyLoss()
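As a small illustration of how the criterion is used: it takes raw scores (logits) and integer class labels, and applies the softmax internally, which is why the model itself only needs a linear layer (the three-class tensors below are made-up values):
In [ ]:
z_example = torch.tensor([[2.0, 0.5, -1.0]])   # raw scores (logits) for 3 hypothetical classes
y_example = torch.tensor([0])                  # the true class index
print(criterion(z_example, y_example))         # softmax + negative log-likelihood in one call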
Optimizer class:
In [ ]:
learning_rate=0.1
optimizer=torch.optim.SGD(model.parameters(), lr=learning_rate)
Define the dataset loader:
In [ ]:
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=100)
validation_loader=torch.utils.data.DataLoader(dataset=validation_dataset,batch_size=5000)
Train the model and determine validation accuracy:
In [ ]:
n_epochs=10
loss_list=[]
accuracy_list=[]
N_test=len(validation_dataset)
for epoch in range(n_epochs):
    for x, y in train_loader:
        # clear the gradient
        optimizer.zero_grad()
        # make a prediction
        z = model(x.view(-1, 28*28))
        # calculate the loss
        loss = criterion(z, y)
        # calculate the gradients of the parameters
        loss.backward()
        # update the parameters
        optimizer.step()

    # perform a prediction on the validation data
    correct = 0
    for x_test, y_test in validation_loader:
        z = model(x_test.view(-1, 28*28))
        _, yhat = torch.max(z.data, 1)
        correct += (yhat == y_test).sum().item()
    accuracy = correct / N_test
    loss_list.append(loss.data)
    accuracy_list.append(accuracy)
Plot the loss and accuracy on the validation data:
In [ ]:
fig, ax1 = plt.subplots()
color = 'tab:red'
ax1.plot(loss_list,color=color)
ax1.set_xlabel('epoch',color=color)
ax1.set_ylabel('total loss',color=color)
ax1.tick_params(axis='y', labelcolor=color)
ax2 = ax1.twinx()
color = 'tab:blue'
ax2.set_ylabel('accuracy', color=color)
ax2.plot( accuracy_list, color=color)
ax2.tick_params(axis='y', labelcolor=color)
fig.tight_layout()
Plot the first five misclassified samples:
In [ ]:
count = 0
for x, y in validation_dataset:
    z = model(x.reshape(-1, 28*28))
    _, yhat = torch.max(z, 1)
    if yhat != y:
        show_data((x, y))
        plt.show()
        print("yhat:", yhat)
        count += 1
    if count >= 5:
        break
Joseph Santarcangelo has a PhD in Electrical Engineering. His research focused on using machine learning, signal processing, and computer vision to determine how videos impact human cognition.
Other contributors: Michelle Carey, Mavis Zhou
Copyright © 2018 cognitiveclass.ai. This notebook and its source code are released under the terms of the MIT License.