In this example we will look at how a pyxis LMDB can be used with PyTorch's torch.utils.data.Dataset and torch.utils.data.DataLoader.
In [1]:
from __future__ import print_function
import numpy as np
import pyxis as px
As usual, we will begin by creating a small dataset to test with. It will consist of 10 samples, where each input observation has four features and each target is a scalar value.
In [2]:
nb_samples = 10
X = np.outer(np.arange(1, nb_samples + 1, dtype=np.uint8), np.arange(1, 4 + 1, dtype=np.uint8))
y = np.arange(nb_samples, dtype=np.uint8)
for i in range(nb_samples):
    print('Input: {} -> Target: {}'.format(X[i], y[i]))
The data is written using a with statement.
In [3]:
with px.Writer(dirpath='data', map_size_limit=10, ram_gb_limit=1) as db:
    db.put_samples('input', X, 'target', y)
To be sure the data was stored correctly, we will read the data back - again using a with statement.
In [4]:
with px.Reader('data') as db:
    print(db)
In [5]:
try:
    import torch
    import torch.utils.data
except ImportError:
    raise ImportError('Could not import the PyTorch library `torch` or '
                      '`torch.utils.data`. Please refer to '
                      'https://pytorch.org/ for installation instructions.')
In pyxis.torch we have implemented a wrapper around torch.utils.data.Dataset called pyxis.torch.TorchDataset. This object is not imported into the pyxis namespace because it relies on PyTorch being installed. As such, we first need to import pyxis.torch:
In [6]:
import pyxis.torch as pxt
pyxis.torch.TorchDataset has a single constructor argument: dirpath, i.e. the location of the pyxis LMDB.
In [7]:
dataset = pxt.TorchDataset('data')
The pyxis.torch.TorchDataset object has only three methods: __len__, __getitem__, and __repr__, each of which is demonstrated below:
In [8]:
len(dataset)
Out[8]:
In [9]:
dataset[0]
Out[9]:
In [10]:
dataset
Out[10]:
pyxis.torch.TorchDataset can be combined directly with torch.utils.data.DataLoader to create an iterable over mini-batches:
In [11]:
use_cuda = torch.cuda.is_available()
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, **kwargs)

for i, d in enumerate(loader):
    print('Batch:', i)
    print('\t', d['input'])
    print('\t', d['target'])
As with the built-in iterators in pyxis.iterators, we recommend you inherit from pyxis.torch.TorchDataset and override __getitem__ to include your own data transformations.
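The subclassing pattern can be sketched as follows. This is a minimal, self-contained illustration: a stand-in base class mimics TorchDataset's behaviour of returning one sample as a dict, so the example runs without pyxis or PyTorch installed. In real code you would subclass pxt.TorchDataset instead of the stand-in, and the samples would be torch tensors read from the LMDB; the normalisation constant 255.0 is an arbitrary example choice.

```python
class BaseDataset:
    """Stand-in for pyxis.torch.TorchDataset: one sample per key,
    returned as a dict of 'input' and 'target' entries."""
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, key):
        return {'input': self.inputs[key], 'target': self.targets[key]}


class NormalisedDataset(BaseDataset):
    """Overrides __getitem__ to apply a transformation on top of the
    base class's sample lookup - the pattern recommended above."""
    def __getitem__(self, key):
        sample = super(NormalisedDataset, self).__getitem__(key)
        # Example transformation: scale uint8-style features to [0, 1]
        sample['input'] = [v / 255.0 for v in sample['input']]
        return sample


dataset = NormalisedDataset([[0, 255], [51, 102]], [0, 1])
print(dataset[0])  # {'input': [0.0, 1.0], 'target': 0}
```

With pyxis installed, the subclass would instead begin `class NormalisedDataset(pxt.TorchDataset):` and call `super().__getitem__(key)` in exactly the same way, so the DataLoader usage shown above is unchanged.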