Data Loading Tutorial

Loading data is one crucial step in the deep learning pipeline. PyTorch makes it easy to write custom data loaders for your particular dataset. In this notebook, we're going to download and load cifar-10 dataset and return a torch "tensor" as the image and the label.

Getting the dataset

Head over to CIFAR-10 dataset homepage and download the "python" version mentioned on the page. Extract the tar.gz archive in your home directory.

Exploring the dataset

CIFAR-10 dataset is divided into 5 training batches and one test batch. In order to train effectively we need to use all the training data. The datasets themselves are stored in a "pickle" format so we have to "unpickle" them first:

In [2]:
import sys, os 
import pickle 

def unpickle(fname):
    with open(fname, 'rb') as f:
        Dict = pickle.load(f, encoding='bytes')
    return Dict

Let's test if we actually can load the data after the unpickling. For that we're going to load one the data batches and see if we can find the data and labels from it:

In [3]:
Dic = unpickle(os.path.join('/home/akulshr','cifar-10-batches-py', 'data_batch_1'))

(10000, 3072)

We notice that the data is a numpy array of 1000x3072 which means that the data batch contains 10000 images of size 3072 (32x32x3). The labels are 1000x1 list in which the ith element corresponds to the correct label for the ith image. This is how data is in the other training batches. Whenever you've got a new dataset always try to load one or two images(or a data batch) like in this case to get a feel of how the data is.

Loading data like this can be tedious and time consuming. Let's write a helper function which will return the data and the labels:

In [4]:
def load_data(batch):
    print ("Loading batch:{}".format(batch))
    return unpickle(batch)

This little helper function will call our "unpickling" function to unpickle the data and return the unpicked data back to us.In datasets where there is only images, you can use PIL.Image. Armed with this let's write a dataloader for PyTorch.

DataLoading in PyTorch

Dataloading in PyTorch is a two step process. First, you need define a custom class which is subclassed from a data.Dataset class in This class takes in arguments which tell PyTorch about the location of the dataset, any "transforms" that you need to make before the dataset is loading.

First lets import the relevant modules. The data module which contains the useful functions for dataloading is contained in We also need a utility called glob to get all our batches in a list. The dataset is split into 5 different parts and we need to have them all together so we can use the entire training set:

In [5]:
import torch 
import as data 
import glob
from PIL import Image
import numpy as np

In [6]:

Every dataloader in PyTorch needs this form. The class you write for a DataLoader contains two methods __init__ and __getitem__.

The __init__ method is where you define arguments telling the location of your dataset, transforms (if any). It is also a good idea to have a variable which can hold a list of all images/batches in your dataset. There are quite a number of things going on in the function we wrote. Let's break it down:

  • We create two lists train_data and train_labels to contain the data and their associated labels. This is helpful since we have 5 training batches of 10,000 images each and it will be time consuming to do the same operation over and over again for all 5 batches.

  • We read in all batches at once in our for loop. This is a better approach almost every time and should be applied to any dataset that you're trying to read.

  • We use a funcion called np.concatenate to "concatenate" the training data which is in the form of numpy arrays. The labels are in the form of lists and hence we simply use the python concat operator + to concatenate them.

The __getitem__ method should accept only one argument: the index of the image/batch you want to access. We get to enjoy our hardwork as we can simply load a data and label by one index and return it. When writing your custom dataloaders it is a good practice to keep your __getitem__ as simple as possible.

Having written a dataloader, we can now test it. For the time being, don't worry about the test code, we'll get to it in later talks. The code below can be used for testing the dataloader class above:

In [8]:
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

tfs  = transforms.Compose([transforms.ToTensor(),
                          transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])# convert any data into a torch tensor


cifar_train = CIFARLoader(root, transform=tfs) # create a "CIFARLoader instance".
cifar_loader = data.DataLoader(cifar_train, batch_size=4, shuffle=True, num_workers=2)

data_iter = iter(cifar_loader)
data,label =

Loading batch:/home/akulshr/cifar-10-batches-py/data_batch_1
Loading batch:/home/akulshr/cifar-10-batches-py/data_batch_2
Loading batch:/home/akulshr/cifar-10-batches-py/data_batch_3
Loading batch:/home/akulshr/cifar-10-batches-py/data_batch_4
Loading batch:/home/akulshr/cifar-10-batches-py/data_batch_5
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])

In our test function we create an "instance" of the dataloader. Notice how we don't pass anything to the target_transformparameter. This is because our labels in this case are in the form of lists. If they are in a different form then we have to employ different strategies, but this is rare in practice. The output should show you torch.Size([3,32,32]) which is pytorch way of saying that your data is now a "tensor" with dimensions 3x32x32. This is correct since we know that CIFAR-10 data has dimensions 3x32x32.

In [ ]: