In [1]:
import sys
import numpy as np
# the following line is not required if BatchFlow is installed as a python package.
sys.path.append("../..")
from batchflow import Dataset, DatasetIndex, Batch
In [2]:
# number of items in the dataset
NUM_ITEMS = 10
# number of items in a batch when iterating
BATCH_SIZE = 3
A dataset is defined by an index (a sequence of item ids) and a batch class (see the documentation for details).
In the simplest case an index is a natural sequence 0, 1, 2, 3, ...
So all you need to define the index is just the number of items in the dataset.
In [3]:
dataset = Dataset(index=NUM_ITEMS, batch_class=Batch)
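To verify that the index is indeed a natural sequence, you can print it (a quick check, assuming the dataset exposes the same indices attribute that its train and test parts expose later in this tutorial):
# the index built from an integer is just the item ids 0 .. NUM_ITEMS - 1
print(dataset.indices)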
See the documentation for more info about how to create an index which fits your needs.
Here are the most frequent use cases:
client_index = DatasetIndex(my_client_ids)
images_index = FilesIndex(path="/path/to/images/*.jpg", no_ext=True)
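For instance, a DatasetIndex can be built from any sequence of item ids. A minimal sketch (client_ids below is made-up data; FilesIndex would additionally need to be imported from batchflow):
# build an index from arbitrary item ids and wrap it into a dataset
client_ids = ['c-01', 'c-07', 'c-13', 'c-42']
client_index = DatasetIndex(client_ids)
client_dataset = Dataset(index=client_index, batch_class=Batch)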
gen_batch is a Python generator.
In [4]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1)):
    print("batch", i, " contains items", batch.indices)
In [5]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_iters=5)):
    print("batch", i, " contains items", batch.indices)
In [6]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1, drop_last=True)):
    print("batch", i, " contains items", batch.indices)
In [7]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_iters=4, drop_last=True, shuffle=True)):
    print("batch", i, " contains items", batch.indices)
Run the cell above multiple times to see how batches change.
In [8]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1, drop_last=True, shuffle=123)):
    print("batch", i, " contains items", batch.indices)
Run the cell above multiple times to see that batches stay the same across runs.
While gen_batch is a generator, next_batch is an ordinary method. Most of the time you will use gen_batch, but for deeper control over training and more sophisticated fine-tuning, next_batch might be more convenient.
If too many iterations are made, StopIteration will be raised.
Note that the loop below asks for NUM_ITEMS * 3 batches, but n_epochs=2 is specified inside the next_batch() call, so StopIteration is raised before the loop completes.
In [9]:
for i in range(NUM_ITEMS * 3):
    try:
        batch = dataset.next_batch(BATCH_SIZE, shuffle=True, n_epochs=2, drop_last=True)
        print("batch", i + 1, "contains items", batch.indices)
    except StopIteration:
        print("got StopIteration")
        break
Do not forget to reset the iterator to start next_batch'ing from scratch.
In [10]:
dataset.reset('iter')
n_epochs=None allows for infinite iterations.
In [11]:
for i in range(int(NUM_ITEMS * 1.3)):
    # the requested batch size varies between BATCH_SIZE and BATCH_SIZE + 2 from iteration to iteration
    batch = dataset.next_batch(BATCH_SIZE + (-1)**i * i % 3, shuffle=True, n_epochs=None, drop_last=True)
    print("batch", i + 1, "contains items", batch.indices)
To get a deeper understanding of drop_last, read the very important notes in the API documentation.
For illustrative purposes let's create a small array which will serve as a raw data source.
In [12]:
data = np.arange(NUM_ITEMS).reshape(-1, 1) * 100 + np.arange(3).reshape(1, -1)
data
Out[12]:
After loading, the data is available as batch.data
In [13]:
for batch in dataset.gen_batch(BATCH_SIZE, n_epochs=1):
    batch = batch.load(src=data)
    print("batch contains items with indices", batch.indices)
    print('and batch data is')
    print(batch.data)
    print()
In [14]:
for batch in dataset.gen_batch(BATCH_SIZE, n_epochs=1):
    batch = batch.load(src=data)
    print("batch contains")
    for item in batch:
        print(item)
    print()
Quite often a batch stores more complex data structures, e.g. features and labels, or images, masks, bounding boxes and labels. To work with these you might employ data components. Just define the components attribute as follows:
In [15]:
class MyBatch(Batch):
    components = 'features', 'labels'
Let's generate some random data:
In [16]:
features_array = np.arange(NUM_ITEMS).reshape(-1, 1) * 100 + np.arange(3).reshape(1, -1)
labels_array = np.random.choice(10, size=NUM_ITEMS)
data = features_array, labels_array
Now create a dataset (the preloaded argument handles loading data that is already stored in memory).
In [17]:
dataset = Dataset(index=NUM_ITEMS, batch_class=MyBatch, preloaded=data)
Since components are defined, you can address them as batch and even item attributes (they are created and loaded automatically).
In [18]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1)):
    print("batch", i, " contains items", batch.indices)
    print("and batch data consists of features:")
    print(batch.features)
    print("and labels:", batch.labels)
    print()
In [19]:
for i, batch in enumerate(dataset.gen_batch(BATCH_SIZE, n_epochs=1)):
    print("Batch", i)
    for item in batch:
        print("item features:", item.features, " item label:", item.labels)
    print()
    print("You can change batch data, even scalars.")
    for item in batch:
        item.features = item.features + 1000
        item.labels = item.labels + 100
    print("New batch features:\n", batch.features)
    print("and labels:", batch.labels)
    print()
For machine learning tasks you might need to split a dataset into train, test and validation parts.
In [20]:
dataset.split(0.8)
Now the dataset is split into train / test with an 80/20 ratio.
In [21]:
len(dataset.train), len(dataset.test)
Out[21]:
In [22]:
dataset.split([.6, .2, .2])
In [23]:
len(dataset.train), len(dataset.test), len(dataset.validation)
Out[23]:
The dataset may be shuffled before splitting.
In [24]:
dataset.split(0.7, shuffle=True)
In [25]:
dataset.train.indices, dataset.test.indices
Out[25]:
As always, shuffle can be a bool, an int (a seed), or a RandomState object.
dataset.train and dataset.test are also datasets, so you can do anything you want with them, including splitting them further into dataset.train.train, etc.
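For instance, passing a seeded RandomState (or an int seed) makes the split reproducible across runs. A small sketch, assuming split accepts the same shuffle values as gen_batch and next_batch:
# reproducible split: the same seed yields the same partition on every run
dataset.split([.6, .2, .2], shuffle=np.random.RandomState(42))
print(dataset.train.indices, dataset.test.indices, dataset.validation.indices)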
Most of the time, though, you will work with pipelines, not datasets.
See the pipeline operations tutorial for details, or return to the table of contents.