PyTorch dataset interface

In this example we will look at how a pyxis LMDB can be used with PyTorch's torch.utils.data.Dataset and torch.utils.data.DataLoader.


In [1]:
from __future__ import print_function

import numpy as np

import pyxis as px

As usual, we will begin by creating a small dataset to test with. It will consist of 10 samples, where each input observation has four features and each target is a scalar value.


In [2]:
nb_samples = 10

X = np.outer(np.arange(1, nb_samples + 1, dtype=np.uint8), np.arange(1, 4 + 1, dtype=np.uint8))
y = np.arange(nb_samples, dtype=np.uint8)

for i in range(nb_samples):
    print('Input: {} -> Target: {}'.format(X[i], y[i]))


Input: [1 2 3 4] -> Target: 0
Input: [2 4 6 8] -> Target: 1
Input: [ 3  6  9 12] -> Target: 2
Input: [ 4  8 12 16] -> Target: 3
Input: [ 5 10 15 20] -> Target: 4
Input: [ 6 12 18 24] -> Target: 5
Input: [ 7 14 21 28] -> Target: 6
Input: [ 8 16 24 32] -> Target: 7
Input: [ 9 18 27 36] -> Target: 8
Input: [10 20 30 40] -> Target: 9

The data is written with a pyxis.Writer inside a with statement.


In [3]:
with px.Writer(dirpath='data', map_size_limit=10, ram_gb_limit=1) as db:
    db.put_samples('input', X, 'target', y)

To be sure the data was stored correctly, we will read it back, again using a with statement.


In [4]:
with px.Reader('data') as db:
    print(db)


pyxis.Reader
Location:		'data'
Number of samples:	10
Data keys (0th sample):
	'input' <- dtype: uint8, shape: (4,)
	'target' <- dtype: uint8, shape: ()

Working with PyTorch


In [5]:
try:
    import torch
    import torch.utils.data
except ImportError:
    raise ImportError('Could not import the PyTorch library `torch` or '
                      '`torch.utils.data`. Please refer to '
                      'https://pytorch.org/ for installation instructions.')

In pyxis.torch we have implemented a wrapper around torch.utils.data.Dataset called pyxis.torch.TorchDataset. This object is not imported into the pyxis namespace because it relies on PyTorch being installed. As such, we first need to import pyxis.torch:


In [6]:
import pyxis.torch as pxt

pyxis.torch.TorchDataset has a single constructor argument: dirpath, i.e. the location of the pyxis LMDB.


In [7]:
dataset = pxt.TorchDataset('data')

The pyxis.torch.TorchDataset object has only three methods: __len__, __getitem__, and __repr__. An example of each is shown below:


In [8]:
len(dataset)


Out[8]:
10

In [9]:
dataset[0]


Out[9]:
{'input': tensor([ 1,  2,  3,  4], dtype=torch.uint8),
 'target': tensor(0, dtype=torch.uint8)}

In [10]:
dataset


Out[10]:
pyxis.Reader
Location:		'data'
Number of samples:	10
Data keys (0th sample):
	'input' <- dtype: uint8, shape: (4,)
	'target' <- dtype: uint8, shape: ()

pyxis.torch.TorchDataset can be combined directly with torch.utils.data.DataLoader to create an iterable that yields batches of samples:


In [11]:
# Use worker processes and pinned memory only when CUDA is available.
use_cuda = True and torch.cuda.is_available()
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, **kwargs)

for i, d in enumerate(loader):
    print('Batch:', i)
    print('\t', d['input'])
    print('\t', d['target'])


Batch: 0
	 tensor([[ 1,  2,  3,  4],
        [ 2,  4,  6,  8]], dtype=torch.uint8)
	 tensor([ 0,  1], dtype=torch.uint8)
Batch: 1
	 tensor([[  3,   6,   9,  12],
        [  4,   8,  12,  16]], dtype=torch.uint8)
	 tensor([ 2,  3], dtype=torch.uint8)
Batch: 2
	 tensor([[  5,  10,  15,  20],
        [  6,  12,  18,  24]], dtype=torch.uint8)
	 tensor([ 4,  5], dtype=torch.uint8)
Batch: 3
	 tensor([[  7,  14,  21,  28],
        [  8,  16,  24,  32]], dtype=torch.uint8)
	 tensor([ 6,  7], dtype=torch.uint8)
Batch: 4
	 tensor([[  9,  18,  27,  36],
        [ 10,  20,  30,  40]], dtype=torch.uint8)
	 tensor([ 8,  9], dtype=torch.uint8)

As with the built-in iterators in pyxis.iterators, we recommend that you inherit from pyxis.torch.TorchDataset and override __getitem__ to add your own data transformations.
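
A minimal sketch of such a subclass is shown below. The class name ScaledTorchDataset and the specific transformations (casting the input to float32 and rescaling it, casting the target to int64) are purely illustrative; adapt them to your own data.

import pyxis.torch as pxt


class ScaledTorchDataset(pxt.TorchDataset):
    """Illustrative TorchDataset subclass that transforms samples on the fly."""

    def __getitem__(self, key):
        # Fetch the raw sample: a dict of tensors, e.g.
        # {'input': uint8 tensor of shape (4,), 'target': uint8 scalar tensor}.
        sample = super(ScaledTorchDataset, self).__getitem__(key)

        # Illustrative transformations: cast the input to float32 and rescale
        # it, and cast the target to int64 for use with typical loss functions.
        sample['input'] = sample['input'].float() / 255.0
        sample['target'] = sample['target'].long()
        return sample


dataset = ScaledTorchDataset('data')
print(dataset[0])

The resulting object can be passed to torch.utils.data.DataLoader exactly as before.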