How to generate batches from a dataset and work with batch components


In [1]:
import sys
import numpy as np

# the following line is not required if BatchFlow is installed as a python package.
sys.path.append("../..")
from batchflow import Dataset, DatasetIndex, Batch

In [2]:
# number of items in the dataset
NUM_ITEMS = 10
# number of items in a batch when iterating
BATCH_SIZE = 3

Create a dataset

A dataset is defined by an index (a sequence of item ids) and a batch class (see the documentation for details).

In the simplest case an index is a natural sequence 0, 1, 2, 3, ...

So all you need to define the index is the number of items in the dataset.


In [3]:
dataset = Dataset(index=NUM_ITEMS, batch_class=Batch)

The dataset index

See the documentation for more info about how to create an index which fits your needs.

Here are the most frequent use cases:

client_index = DatasetIndex(my_client_ids)    # my_client_ids is any sequence of unique item ids
images_index = FilesIndex(path="/path/to/images/*.jpg", no_ext=True)    # FilesIndex is also imported from batchflow

Iterate with gen_batch(...)

gen_batch is a Python generator.


In [4]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1)):
    print("batch", i, " contains items", batch.indices)


batch 0  contains items [0 1 2]
batch 1  contains items [3 4 5]
batch 2  contains items [6 7 8]
batch 3  contains items [9]

In [5]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_iters=5)):
    print("batch", i, " contains items", batch.indices)


batch 0  contains items [0 1 2]
batch 1  contains items [3 4 5]
batch 2  contains items [6 7 8]
batch 3  contains items [9 0 1]
batch 4  contains items [2 3 4]

drop_last=True skips the last batch if it contains fewer than BATCH_SIZE items


In [6]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1, drop_last=True)):
    print("batch", i, " contains items", batch.indices)


batch 0  contains items [0 1 2]
batch 1  contains items [3 4 5]
batch 2  contains items [6 7 8]

shuffle=True randomly permutes items across batches


In [7]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_iters=4, drop_last=True, shuffle=True)):
    print("batch", i, " contains items", batch.indices)


batch 0  contains items [4 6 9]
batch 1  contains items [3 8 5]
batch 2  contains items [0 1 2]
batch 3  contains items [7 5 1]

Run the cell above multiple times to see how batches change.

shuffle can be a bool, an int (a seed number) or a numpy RandomState object


In [8]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1, drop_last=True, shuffle=123)):
    print("batch", i, " contains items", batch.indices)


batch 0  contains items [4 0 7]
batch 1  contains items [5 8 3]
batch 2  contains items [1 6 9]

Run the cell above multiple times to see that batches stay the same across runs.
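
Since shuffle also accepts a numpy RandomState, you can pass one directly for reproducibility. A minimal sketch (the rng name is ours, not part of the API):

rng = np.random.RandomState(42)
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1, drop_last=True, shuffle=rng)):
    print("batch", i, " contains items", batch.indices)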

Iterate with next_batch(...)

While gen_batch is a generator, next_batch is an ordinary method. Most of the time you will use gen_batch, but when you need finer control over training or more sophisticated fine-tuning, next_batch might be more convenient.

If too many iterations are made, StopIteration will be raised.

Note that the loop below allows up to NUM_ITEMS * 3 iterations, but n_epochs=2 is specified inside the next_batch() call, so iteration stops early with StopIteration.


In [9]:
for i in range(NUM_ITEMS * 3):
    try:
        batch = dataset.next_batch(BATCH_SIZE, shuffle=True, n_epochs=2, drop_last=True)
        print("batch", i + 1, "contains items", batch.indices)
    except StopIteration:
        print("got StopIteration")
        break


batch 1 contains items [3 4 0]
batch 2 contains items [9 5 2]
batch 3 contains items [8 1 7]
batch 4 contains items [0 5 1]
batch 5 contains items [2 7 4]
batch 6 contains items [3 9 6]
got StopIteration

And finally with shuffle=True, n_epochs=None and a variable batch size

Do not forget to reset the iterator to start next_batch'ing from scratch.


In [10]:
dataset.reset('iter')

n_epochs=None allows for infinite iterations.


In [11]:
for i in range(int(NUM_ITEMS * 1.3)):
    batch = dataset.next_batch(BATCH_SIZE + (-1)**i * i % 3, shuffle=True, n_epochs=None, drop_last=True)
    print("batch", i + 1, "contains items", batch.indices)


batch 1 contains items [0 6 9]
batch 2 contains items [5 3 4 8 2]
batch 3 contains items [2 0 6 4 9]
batch 4 contains items [3 5 7]
batch 5 contains items [9 1 0 2]
batch 6 contains items [5 3 7 8]
batch 7 contains items [8 2 1]
batch 8 contains items [7 0 5 3 6]
batch 9 contains items [9 4 1 5 8]
batch 10 contains items [7 3 6]
batch 11 contains items [4 5 6 0]
batch 12 contains items [3 9 1 7]
batch 13 contains items [7 5 3]

To get a deeper understanding of drop_last, read the very important notes in the API documentation.

Working with data

For illustrative purposes let's create a small array which will serve as a raw data source.


In [12]:
data = np.arange(NUM_ITEMS).reshape(-1, 1) * 100 + np.arange(3).reshape(1, -1)
data


Out[12]:
array([[  0,   1,   2],
       [100, 101, 102],
       [200, 201, 202],
       [300, 301, 302],
       [400, 401, 402],
       [500, 501, 502],
       [600, 601, 602],
       [700, 701, 702],
       [800, 801, 802],
       [900, 901, 902]])

Load data into a batch

After loading, the data is available as batch.data


In [13]:
for batch in dataset.gen_batch(BATCH_SIZE, n_epochs=1):
    batch = batch.load(src=data)
    print("batch contains items with indices", batch.indices)
    print('and batch data is')
    print(batch.data)
    print()


batch contains items with indices [0 1 2]
and batch data is
[[  0   1   2]
 [100 101 102]
 [200 201 202]]

batch contains items with indices [3 4 5]
and batch data is
[[300 301 302]
 [400 401 402]
 [500 501 502]]

batch contains items with indices [6 7 8]
and batch data is
[[600 601 602]
 [700 701 702]
 [800 801 802]]

batch contains items with indices [9]
and batch data is
[[900 901 902]]

You can easily iterate over batch items too


In [14]:
for batch in dataset.gen_batch(BATCH_SIZE, n_epochs=1):
    batch = batch.load(src=data)
    print("batch contains")
    for item in batch:
        print(item)
    print()


batch contains
[0 1 2]
[100 101 102]
[200 201 202]

batch contains
[300 301 302]
[400 401 402]
[500 501 502]

batch contains
[600 601 602]
[700 701 702]
[800 801 802]

batch contains
[900 901 902]

Data components

Quite often a batch stores more complex data, e.g. features and labels, or images, masks, bounding boxes and labels. To work with these you can employ data components. Just define the components attribute as follows:


In [15]:
class MyBatch(Batch):
    components = 'features', 'labels'

Let's generate some random data:


In [16]:
features_array = np.arange(NUM_ITEMS).reshape(-1, 1) * 100 + np.arange(3).reshape(1, -1)
labels_array = np.random.choice(10, size=NUM_ITEMS)
data = features_array, labels_array

Now create a dataset (the preloaded argument handles loading data that is already stored in memory)


In [17]:
dataset = Dataset(index=NUM_ITEMS, batch_class=MyBatch, preloaded=data)

Since components are defined, you can address them as batch attributes and even as item attributes (they are created and loaded automatically).


In [18]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1)):
    print("batch", i, " contains items", batch.indices)
    print("and batch data consists of features:")
    print(batch.features)
    print("and labels:", batch.labels)
    print()


batch 0  contains items [0 1 2]
and batch data consists of features:
[[  0   1   2]
 [100 101 102]
 [200 201 202]]
and labels: [7 9 6]

batch 1  contains items [3 4 5]
and batch data consists of features:
[[300 301 302]
 [400 401 402]
 [500 501 502]]
and labels: [7 8 8]

batch 2  contains items [6 7 8]
and batch data consists of features:
[[600 601 602]
 [700 701 702]
 [800 801 802]]
and labels: [0 9 4]

batch 3  contains items [9]
and batch data consists of features:
[[900 901 902]]
and labels: [9]

You can iterate over batch items and change them on the fly


In [19]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1)):
    print("Batch", i)
    for item in batch:
        print("item features:", item.features, "    item label:", item.labels)

    print()
    print("You can change batch data, even scalars.")
    for item in batch:
        item.features = item.features + 1000
        item.labels = item.labels + 100
    print("New batch features:\n", batch.features)
    print("and labels:", batch.labels)
    print()


Batch 0
item features: [0 1 2]     item label: 7
item features: [100 101 102]     item label: 9
item features: [200 201 202]     item label: 6

You can change batch data, even scalars.
New batch features:
 [[1000 1001 1002]
 [1100 1101 1102]
 [1200 1201 1202]]
and labels: [107 109 106]

Batch 1
item features: [300 301 302]     item label: 7
item features: [400 401 402]     item label: 8
item features: [500 501 502]     item label: 8

You can change batch data, even scalars.
New batch features:
 [[1300 1301 1302]
 [1400 1401 1402]
 [1500 1501 1502]]
and labels: [107 108 108]

Batch 2
item features: [600 601 602]     item label: 0
item features: [700 701 702]     item label: 9
item features: [800 801 802]     item label: 4

You can change batch data, even scalars.
New batch features:
 [[1600 1601 1602]
 [1700 1701 1702]
 [1800 1801 1802]]
and labels: [100 109 104]

Batch 3
item features: [900 901 902]     item label: 9

You can change batch data, even scalars.
New batch features:
 [[1900 1901 1902]]
and labels: [109]

Splitting a dataset

For machine learning tasks you might need to split a dataset into train, test and validation parts.


In [20]:
dataset.split(0.8)

Now the dataset is split into train / test in an 80/20 ratio.


In [21]:
len(dataset.train), len(dataset.test)


Out[21]:
(8, 2)

In [22]:
dataset.split([.6, .2, .2])

In [23]:
len(dataset.train), len(dataset.test), len(dataset.validation)


Out[23]:
(6, 2, 2)

The dataset may be shuffled before splitting.


In [24]:
dataset.split(0.7, shuffle=True)

In [25]:
dataset.train.indices, dataset.test.indices


Out[25]:
(array([6, 2, 8, 4, 5, 7, 0]), array([3, 1, 9]))

As always, shuffle can be a bool, an int (a seed number) or a numpy RandomState object.

dataset.train and dataset.test are also datasets, so you can do anything you want with them, including splitting them further into dataset.train.train, etc.
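
For instance, a further split might look like this (a sketch that only reuses the calls shown above; the exact lengths depend on rounding):

dataset.train.split(0.75)
len(dataset.train.train), len(dataset.train.test)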

Most of the time, though, you will work with pipelines, not datasets.

See the pipeline operations tutorial for details or return to the table of contents.