TensorFlow Datasets

TFDS provides a collection of ready-to-use datasets for use with TensorFlow, Jax, and other Machine Learning frameworks.

It handles downloading and preparing the data deterministically and constructing a tf.data.Dataset (or np.array).

Note: Do not confuse TFDS (this library) with tf.data (the TensorFlow API to build efficient data pipelines). TFDS is a high-level wrapper around tf.data. If you're not familiar with this API, we encourage you to read the official tf.data guide first.

Copyright 2018 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0

Installation

TFDS exists in two packages:

  • tensorflow-datasets: The stable version, released every few months.
  • tfds-nightly: Released every day, contains the latest versions of the datasets.

To install:

pip install tensorflow-datasets
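
Or, for the nightly version:

pip install tfds-nightly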

Note: TFDS requires tensorflow (or tensorflow-gpu) to be already installed. TFDS supports TF >=1.15.

This colab uses tfds-nightly and TF 2.


In [0]:
!pip install -q "tensorflow>=2" tfds-nightly matplotlib  # quote the specifier so the shell doesn't treat >= as a redirect

In [0]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

Find available datasets

All dataset builders are subclasses of tfds.core.DatasetBuilder. To get the list of available builders, use tfds.list_builders() or look at our catalog.


In [0]:
tfds.list_builders()

Load a dataset

The easiest way of loading a dataset is tfds.load. It will:

  1. Download the data and save it as tfrecord files.
  2. Load the tfrecord and create the tf.data.Dataset.

In [0]:
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)

Some common arguments:

  • split=: Which split to read (e.g. 'train', ['train', 'test'], 'train[80%:]',...). See our split API guide.
  • shuffle_files=: Control whether to shuffle the files between each epoch (TFDS stores big datasets in multiple smaller files).
  • data_dir=: Location where the dataset is saved (defaults to ~/tensorflow_datasets/).
  • with_info=True: Returns the tfds.core.DatasetInfo containing dataset metadata.
  • download=False: Disable download.
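
For example, a minimal sketch combining several of these arguments (the 80% split boundary is just an illustration):

In [0]:
ds, info = tfds.load(
    'mnist',
    split='train[80%:]',   # read only the last 20% of the train split
    shuffle_files=False,   # keep the file read order deterministic
    with_info=True,        # also return the dataset metadata
)
print(info.splits['train[80%:]'].num_examples)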

tfds.load is a thin wrapper around tfds.core.DatasetBuilder. You can get the same output using the tfds.core.DatasetBuilder API:


In [0]:
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if they already exist)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)

Iterate over a dataset

As dict

By default, the tf.data.Dataset object contains a dict of tf.Tensors:


In [0]:
ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)

As tuple

By using as_supervised=True, you can get a (features, label) tuple instead for supervised datasets.


In [0]:
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)

As numpy

Use tfds.as_numpy to convert:

  • tf.Tensor -> np.array
  • tf.data.Dataset -> Generator[np.array]

In [0]:
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)

As batched tf.Tensor

By using batch_size=-1, you can load the full dataset in a single batch.

tfds.load will return a dict (tuple with as_supervised=True) of tf.Tensor (np.array with tfds.as_numpy).

Be careful that your dataset fits in memory, and that all examples have the same shape.


In [0]:
image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test', 
    batch_size=-1, 
    as_supervised=True,
))

print(type(image), image.shape)

Build end-to-end pipeline

To go further, you can look at the end-to-end Keras example and the performance guide to improve the speed of your pipelines.
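
As a rough sketch of such a pipeline (the preprocessing, model, and hyperparameters below are illustrative choices, not the official example):

In [0]:
def normalize_img(image, label):
  # Scale pixel values from [0, 255] to [0., 1.]
  return tf.cast(image, tf.float32) / 255., label

ds_train = tfds.load('mnist', split='train', as_supervised=True, shuffle_files=True)
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()          # cache after the cheap map, before shuffling
    .shuffle(10_000)  # shuffle examples (in addition to shuffling files)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
model.fit(ds_train, epochs=5)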

Visualize a dataset

Visualize datasets with tfds.show_examples (only image datasets are supported for now):


In [0]:
ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)

Access the dataset metadata

All builders include a tfds.core.DatasetInfo object containing the dataset metadata.

It can be accessed through:

  • The tfds.load API:

In [0]:
ds, info = tfds.load('mnist', with_info=True)
  • The tfds.core.DatasetBuilder API:

In [0]:
builder = tfds.builder('mnist')
info = builder.info

The dataset info contains additional information about the dataset (version, citation, homepage, description,...).


In [0]:
print(info)

Features metadata (label names, image shape,...)

Access the tfds.features.FeatureDict:


In [0]:
info.features

Number of classes, label names:


In [0]:
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))

Shapes, dtypes:


In [0]:
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)

Split metadata (e.g. split names, number of examples,...)

Access the tfds.core.SplitDict:


In [0]:
print(info.splits)

Available splits:


In [0]:
print(list(info.splits.keys()))

Get info on individual split:


In [0]:
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)

It also works with the subsplit API:


In [0]:
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)

Citation

If you're using tensorflow-datasets for a paper, please include the following citation, in addition to any citation specific to the datasets used (which can be found in the dataset catalog).

@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets}},
}