TFDS provides a collection of ready-to-use datasets for use with TensorFlow, Jax, and other Machine Learning frameworks. It handles downloading and preparing the data deterministically and constructing a tf.data.Dataset (or np.array).
Note: Do not confuse TFDS (this library) with tf.data (the TensorFlow API to build efficient data pipelines). TFDS is a high-level wrapper around tf.data. If you're not familiar with this API, we encourage you to read the official tf.data guide first.
Copyright 2018 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0
TFDS exists in two packages:

- tensorflow-datasets: the stable version, released every few months.
- tfds-nightly: released every day, contains the latest versions of the datasets.

To install:
pip install tensorflow-datasets
Note: TFDS requires tensorflow (or tensorflow-gpu) to be already installed. TFDS supports TF >=1.15.
This colab uses tfds-nightly and TF 2.
In [0]:
!pip install -q "tensorflow>=2" tfds-nightly matplotlib
In [0]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
All dataset builders are subclasses of tfds.core.DatasetBuilder. To get the list of available builders, use tfds.list_builders() or look at our catalog.
In [0]:
tfds.list_builders()
The easiest way of loading a dataset is tfds.load. It will:

1. Download the data and save it as tfrecord files.
2. Load the tfrecord files and create the tf.data.Dataset.
In [0]:
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
Some common arguments:
split=
: Which split to read (e.g. 'train'
, ['train', 'test']
, 'train[80%:]'
,...). See our split API guide.shuffle_files=
: Control whether to shuffle the files between each epoch (TFDS store big datasets in multiple smaller files).data_dir=
: Location where the dataset is saved (
defaults to ~/tensorflow_datasets/
)with_info=True
: Returns the tfds.core.DatasetInfo
containing dataset metadatadownload=False
: Disable downloadtfds.load
is a thin wrapper around tfds.core.DatasetBuilder
. You can get the same output using the tfds.core.DatasetBuilder
API:
In [0]:
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if they already exist)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
In [0]:
ds = tfds.load('mnist', split='train')
ds = ds.take(1) # Only take a single example
for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)
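Before training, such a dataset of dicts is typically batched; conceptually, ds.batch(32) stacks consecutive examples field by field into dicts of arrays. A toy numpy stand-in for that behavior (illustrative only, not the real tf.data implementation):

```python
import numpy as np

def batch_examples(examples, batch_size):
    """Toy stand-in for ds.batch(batch_size): group dict examples
    into dicts of stacked arrays, field by field."""
    for i in range(0, len(examples), batch_size):
        chunk = examples[i:i + batch_size]
        yield {k: np.stack([ex[k] for ex in chunk]) for k in chunk[0]}

# Fake MNIST-like examples: 28x28x1 images with integer labels.
examples = [{'image': np.zeros((28, 28, 1), np.uint8),
             'label': np.int64(i % 10)} for i in range(100)]
batches = list(batch_examples(examples, 32))
print(batches[0]['image'].shape)  # (32, 28, 28, 1)
```

As with tf.data, the last batch is smaller when the number of examples is not a multiple of the batch size.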
In [0]:
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in ds:  # example is (image, label)
  print(image.shape, label)
In [0]:
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)
In [0]:
image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test',
    batch_size=-1,
    as_supervised=True,
))
print(type(image), image.shape)
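Since batch_size=-1 materializes the whole split as a single batch, it is only practical when the data fits in memory. A quick back-of-the-envelope estimate for the MNIST test split (10,000 uint8 images of shape 28x28x1):

```python
# Approximate memory footprint of the full MNIST test split as uint8.
num_examples, h, w, c = 10_000, 28, 28, 1
size_mib = num_examples * h * w * c / 2**20  # bytes -> MiB
print(f"~{size_mib:.1f} MiB")  # easily fits in RAM
```

For larger datasets, prefer iterating over batches instead.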
To go further, tfds.show_examples lets you visualize examples from the dataset:
In [0]:
ds, info = tfds.load('mnist', split='train', with_info=True)
fig = tfds.show_examples(ds, info)
In [0]:
ds, info = tfds.load('mnist', with_info=True)
The dataset info can also be accessed through the tfds.core.DatasetBuilder API:
In [0]:
builder = tfds.builder('mnist')
info = builder.info
The dataset info contains additional information about the dataset (version, citation, homepage, description, ...).
In [0]:
print(info)
In [0]:
info.features
Number of classes, label names:
In [0]:
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version ('7')
print(info.features["label"].str2int('7'))
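int2str and str2int are inverse lookups between integer label ids and label names; for MNIST the names are simply '0' through '9'. A plain-Python sketch of the mapping (illustrative only, not the TFDS implementation):

```python
# Toy sketch of the ClassLabel id <-> name mapping for MNIST.
names = [str(i) for i in range(10)]  # MNIST label names: '0'..'9'

def int2str(label_id):
    return names[label_id]

def str2int(name):
    return names.index(name)

print(int2str(7))    # '7'
print(str2int('7'))  # 7
```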
Shapes, dtypes:
In [0]:
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
In [0]:
print(info.splits)
Available splits:
In [0]:
print(list(info.splits.keys()))
Get info on individual split:
In [0]:
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
It also works with the subsplit API:
In [0]:
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
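The number of examples in a percent subsplit follows from its boundaries; assuming each percent boundary is simply rounded to the nearest example index (TFDS's exact rounding rule at the margins may differ), the size can be sketched as:

```python
def subsplit_size(num_examples, start_pct, end_pct):
    """Approximate size of 'train[start%:end%]' by rounding
    each percent boundary to the nearest example index."""
    start = round(num_examples * start_pct / 100)
    end = round(num_examples * end_pct / 100)
    return end - start

# MNIST train has 60,000 examples, so train[15%:75%] covers:
print(subsplit_size(60_000, 15, 75))  # 36000
```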
If you're using tensorflow-datasets for a paper, please include the following citation, in addition to any citation specific to the datasets used (which can be found in the dataset catalog).
@misc{TFDS,
title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
howpublished = {\url{https://www.tensorflow.org/datasets}},
}