Generating batches

In this example we will explore how to read a simple Lightning Memory-Mapped Database (LMDB) with pyxis using iterators.


In [1]:
from __future__ import print_function

import numpy as np

import pyxis as px

For reproducibility, we will use a seeded random number generator for the iterators that rely on randomness.


In [2]:
rng = np.random.RandomState(1234)

Let's start by creating a small dataset of 10 samples. Each input is a randomly generated image with shape (254, 254, 3), while the targets are scalar values.


In [3]:
nb_samples = 10

X = rng.rand(nb_samples, 254, 254, 3)
y = np.arange(nb_samples, dtype=np.uint8)

The data is written using the pyxis writer.


In [4]:
# map_size_limit: maximum LMDB size in MB; ram_gb_limit: RAM cap in GB
db = px.Writer(dirpath='data', map_size_limit=30, ram_gb_limit=1)
db.put_samples('X', X, 'y', y)
db.close()

Using batch iterators

Read back the data using the pyxis reader.


In [5]:
db = px.Reader('data')

Example 1 - Number of samples is a multiple of the batch size

In this first example we create a (simple) batch iterator where the number of samples is divisible by the batch size.


In [6]:
gen = px.SimpleBatch(db, keys=('X', 'y'), batch_size=5, shuffle=False)

All the iterators that come with pyxis have the mandatory keys argument. The data returned by the iterator consists of the values to which these keys point. The order of the keys matters: for example, with the keys ('a', 'b') the iterator will return (a_val, b_val), where a_val and b_val are the values associated with the keys 'a' and 'b', respectively.
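As a minimal sketch of this ordering (it simply reuses the reader opened above), reversing the keys reverses the order of the returned values:

# Sketch: swapping the key order swaps the order of the outputs
gen_rev = px.SimpleBatch(db, keys=('y', 'X'), batch_size=5, shuffle=False)
ys, xs = next(gen_rev)  # the targets now come first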

The artificial dataset has 10 samples, so with a batch size of 5 it takes two iterations to go through the whole dataset. The targets of four batches are printed out to showcase this.

The endless argument is True by default, which means that after having gone through the whole dataset, the iterator starts over from the beginning.
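As a minimal sketch of the opposite behaviour (the same endless argument is used by the custom iterator at the end of this notebook), setting endless=False gives a finite iterator that is exhausted after a single pass:

# Sketch: a finite iterator that stops after one pass over the data
gen_finite = px.SimpleBatch(db, keys=('X', 'y'), batch_size=5,
                            shuffle=False, endless=False)
for xs, ys in gen_finite:
    print('Targets:', ys)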


In [7]:
for i in range(4):
    xs, ys = next(gen)
    print()
    print('Iteration:', i, '\tTargets:', ys)
    if gen.end_of_dataset:
        print('We have reached the end of the dataset')


Iteration: 0 	Targets: [0 1 2 3 4]

Iteration: 1 	Targets: [5 6 7 8 9]
We have reached the end of the dataset

Iteration: 2 	Targets: [0 1 2 3 4]

Iteration: 3 	Targets: [5 6 7 8 9]
We have reached the end of the dataset

Example 2 - Number of samples is not a multiple of the batch size


In [8]:
gen = px.SimpleBatch(db, keys=('X', 'y'), batch_size=3, shuffle=False)

The artificial dataset has 10 samples, so with a batch size of 3 it takes four iterations to go through the whole dataset. The targets of six batches are printed out to showcase this.

Notice that the final batch of the dataset only contains the remaining unseen samples.


In [9]:
for i in range(6):
    xs, ys = next(gen)
    print()
    print('Iteration:', i, '\tTargets:', ys)
    if gen.end_of_dataset:
        print('We have reached the end of the dataset')


Iteration: 0 	Targets: [0 1 2]

Iteration: 1 	Targets: [3 4 5]

Iteration: 2 	Targets: [6 7 8]

Iteration: 3 	Targets: [9]
We have reached the end of the dataset

Iteration: 4 	Targets: [0 1 2]

Iteration: 5 	Targets: [3 4 5]

Example 3 - Shuffling of data

Until now we have created batches by reading samples from the dataset in the order they were written. However, with shuffling turned on, the samples are reshuffled each time we pass through the dataset.

Notice how we only request the values for the y key this time; the key is passed as the one-element tuple ('y',).


In [10]:
gen = px.SimpleBatch(db, keys=('y',), batch_size=5, shuffle=True, rng=rng)

In [11]:
for i in range(6):
    ys = next(gen)
    print()
    print('Iteration:', i, '\tTargets:', ys)
    if gen.end_of_dataset:
        print('We have reached the end of the dataset')


Iteration: 0 	Targets: [6 2 0 1 8]

Iteration: 1 	Targets: [7 3 5 4 9]
We have reached the end of the dataset

Iteration: 2 	Targets: [2 7 3 8 4]

Iteration: 3 	Targets: [1 5 0 9 6]
We have reached the end of the dataset

Iteration: 4 	Targets: [8 2 4 7 3]

Iteration: 5 	Targets: [0 1 6 5 9]
We have reached the end of the dataset

Example 4 - Stochastic batch iterator

Batches can be created stochastically. This means that the samples in a batch are drawn uniformly, with replacement, from the entire dataset, so a batch may contain the same sample more than once. Here we showcase ten different batches with a batch size of 5.


In [12]:
gen = px.StochasticBatch(db, keys=('y',), batch_size=5, rng=rng)

In [13]:
for i in range(10):
    ys = next(gen)
    print('Iteration:', i, '\tTargets:', ys)


Iteration: 0 	Targets: [2 6 3 3 4]
Iteration: 1 	Targets: [9 4 8 0 4]
Iteration: 2 	Targets: [3 2 2 4 4]
Iteration: 3 	Targets: [9 9 1 5 0]
Iteration: 4 	Targets: [3 9 4 1 5]
Iteration: 5 	Targets: [4 5 7 6 5]
Iteration: 6 	Targets: [8 6 6 2 0]
Iteration: 7 	Targets: [7 2 1 0 9]
Iteration: 8 	Targets: [1 5 6 0 7]
Iteration: 9 	Targets: [8 9 1 3 7]

Example 5 - Sequential batch iterator

Batches can be created by reading the database sequentially. The samples in a batch are never shuffled, which allows them to be read at higher speed. The sequential batch iterator is therefore well suited for very large datasets. Here we showcase ten different batches with a batch size of 3.


In [14]:
gen = px.SequentialBatch(db, keys=('y',), batch_size=3)

In [15]:
for i in range(10):
    ys = next(gen)
    print('Iteration:', i, '\tTargets:', ys)


Iteration: 0 	Targets: [0 1 2]
Iteration: 1 	Targets: [3 4 5]
Iteration: 2 	Targets: [6 7 8]
Iteration: 3 	Targets: [9]
Iteration: 4 	Targets: [0 1 2]
Iteration: 5 	Targets: [3 4 5]
Iteration: 6 	Targets: [6 7 8]
Iteration: 7 	Targets: [9]
Iteration: 8 	Targets: [0 1 2]
Iteration: 9 	Targets: [3 4 5]

Example 6 - Thread-safe iterators

The three types of iterators demonstrated so far are:

  • pyxis.SimpleBatch
  • pyxis.StochasticBatch
  • pyxis.SequentialBatch

Each of these comes with a thread-safe variant. By thread-safe we mean that several threads can request batches from the same iterator simultaneously without it raising an exception. These variants have the suffix ThreadSafe:

  • pyxis.SimpleBatchThreadSafe
  • pyxis.StochasticBatchThreadSafe
  • pyxis.SequentialBatchThreadSafe

Other than being thread-safe, they work exactly the same as the non-thread-safe versions.
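Below is a minimal sketch of how such an iterator can be shared between worker threads. The two-thread setup and the consume helper are illustrative only; they are not part of pyxis.

import threading

gen = px.SimpleBatchThreadSafe(db, keys=('X', 'y'), batch_size=5,
                               shuffle=False, endless=True)

def consume(nb_batches):
    # Each worker safely draws batches from the shared iterator
    for _ in range(nb_batches):
        xs, ys = next(gen)

threads = [threading.Thread(target=consume, args=(2,)) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()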

Custom iterators

The iterators demonstrated above always yield the data exactly as it was stored in the LMDB.

To create an iterator that modifies the data, we can, for example, extend one of the existing iterators through inheritance. Here is an example where all targets are squared before the iterator returns them. Notice how thread-safety is achieved by wrapping the read in a with statement on the Python lock object.


In [16]:
class SquareTargets(px.SimpleBatchThreadSafe):
    def __init__(self, db, keys, batch_size):
        super(SquareTargets, self).__init__(db, keys, batch_size,
                                            shuffle=False,
                                            endless=False)

    def __next__(self):
        # Hold the internal lock while reading from the underlying generator
        with self.lock:
            X, y = next(self.gen)

        # Square the targets before handing the batch to the caller
        y = y ** 2

        return X, y

SquareTargets can now be used to generate batches of data from the LMDB.


In [17]:
gen = SquareTargets(db, keys=('X', 'y'), batch_size=2)

print('Squared targets:')
for _, y in gen:
    print(y)


Squared targets:
[0 1]
[4 9]
[16 25]
[36 49]
[64 81]

Close everything

We should make sure to close the LMDB environment after we are done reading.


In [18]:
db.close()