TensorFlow Datasets

TFDS provides a collection of ready-to-use datasets for use with TensorFlow, Jax, and other Machine Learning frameworks.

It handles downloading and preparing the data deterministically and constructing a tf.data.Dataset (or np.array).

Note: Do not confuse TFDS (this library) with tf.data (the TensorFlow API for building efficient data pipelines). TFDS is a high-level wrapper around tf.data. If you're not familiar with this API, we encourage you to read the official tf.data guide first.

TFDS exists in two packages:

  • tensorflow-datasets: The stable version, released every few months.
  • tfds-nightly: Released every day, contains the latest versions of the datasets.

To install:

pip install tensorflow-datasets

Note: TFDS requires tensorflow (or tensorflow-gpu) to be already installed. TFDS supports TF >= 1.15.

This colab uses tfds-nightly and TF 2.

!pip install -q "tensorflow>=2" tfds-nightly matplotlib

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

Find available datasets

All dataset builders are subclasses of tfds.core.DatasetBuilder. To get the list of available builders, use tfds.list_builders() or look at our catalog.

Load a dataset

The easiest way of loading a dataset is tfds.load. It will:

  1. Download the data and save it as tfrecord files.
  2. Load the tfrecord and create the tf.data.Dataset.

ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)

Some common arguments:

  • split=: Which split to read (e.g. 'train', ['train', 'test'], 'train[80%:]',...). See our split API guide.
  • shuffle_files=: Controls whether to shuffle the files between each epoch (TFDS stores big datasets in multiple smaller files).
  • data_dir=: Location where the dataset is saved (defaults to ~/tensorflow_datasets/)
  • with_info=True: Returns the tfds.core.DatasetInfo containing dataset metadata
  • download=False: Disable download

tfds.load is a thin wrapper around tfds.core.DatasetBuilder. You can get the same output using the tfds.core.DatasetBuilder API:

builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)

Iterate over a dataset

As dict

By default, the tf.data.Dataset object contains a dict of tf.Tensors:

ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  image = example["image"]
  label = example["label"]
  print(image.shape, label)

As tuple

By using as_supervised=True, you can get a tuple (features, label) instead for supervised datasets.

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)

As numpy

Use tfds.as_numpy to convert:

  • tf.Tensor -> np.array
  • tf.data.Dataset -> Generator[np.array]

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)

As batched tf.Tensor

By using batch_size=-1, you can load the full dataset in a single batch.

tfds.load will return a dict (tuple with as_supervised=True) of tf.Tensor (np.array with tfds.as_numpy).

Be careful that your dataset can fit in memory, and that all examples have the same shape.

image, label = tfds.as_numpy(tfds.load(
    'mnist', split='test', batch_size=-1, as_supervised=True))  # split choice is illustrative

print(type(image), image.shape)

Build end-to-end pipeline

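To go further, see the end-to-end Keras example (a full training pipeline with batching, shuffling, and so on) and the performance guide in the TFDS documentation.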
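As a rough, offline illustration of the stages such a pipeline usually chains together, the sketch below uses random arrays as a stand-in for `tfds.load('mnist', split='train', as_supervised=True)`. The synthetic data, the `normalize_img` helper, and the buffer and batch sizes are assumptions made for the example, not values prescribed by TFDS.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for an MNIST-like (image, label) dataset, so that the
# sketch runs without downloading anything.
images = np.random.randint(0, 256, size=(256, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(256,), dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))


def normalize_img(image, label):
  """Casts uint8 images to float32 and scales them to [0, 1]."""
  return tf.cast(image, tf.float32) / 255.0, label


ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()        # Keep the preprocessed examples in memory
ds = ds.shuffle(256)   # Buffer covers the whole (small) dataset
ds = ds.batch(32)
ds = ds.prefetch(tf.data.AUTOTUNE)

batch_images, batch_labels = next(iter(ds))
print(batch_images.shape, batch_images.dtype, batch_labels.shape)
```

With a real TFDS dataset, only the first block would change: the `tf.data.Dataset` returned by `tfds.load` would take the place of `ds`.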

Visualize a dataset

Visualize datasets with tfds.show_examples (only image datasets are supported for now):

ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)

Access the dataset metadata

All builders include a tfds.core.DatasetInfo object containing the dataset metadata.

It can be accessed through:

  • The tfds.load API:

ds, info = tfds.load('mnist', with_info=True)

  • The tfds.core.DatasetBuilder API:

builder = tfds.builder('mnist')
info = builder.info

The dataset info contains additional information about the dataset (version, citation, homepage, description,...).

Features metadata (label names, image shape,...)

Access the tfds.features.FeaturesDict:

Number of classes, label names:

print(info.features["label"].int2str(7))  # Human-readable label name (for MNIST: 7 -> '7')

Shapes, dtypes:

Split metadata (e.g. split names, number of examples,...)

Access the tfds.core.SplitDict:

Available splits:

Get info on individual split:

It also works with the subsplit API:

