In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

TensorFlow 数据集

TensorFlow Datasets 提供了一系列可以和 TensorFlow 配合使用的数据集。它负责下载和准备数据,以及构建 tf.data.Dataset

Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为, 所以无法保证它们是最准确的,并且反映了最新的 官方英文文档。如果您有改进此翻译的建议, 请提交 pull request 到 tensorflow/docs GitHub 仓库。要志愿地撰写或者审核译文,请加入 docs-zh-cn@tensorflow.org Google Group

安装

pip install tensorflow-datasets

注意使用 tensorflow-datasets 的前提是已经安装好 TensorFlow,目前支持的版本是 tensorflow (或者 tensorflow-gpu) >= 1.15.0


In [ ]:
!pip install -q tensorflow tensorflow-datasets matplotlib

In [ ]:
import tensorflow as tf

In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

Eager execution

TensorFlow 数据集同时兼容 TensorFlow Eager 模式 和图模式。在这个 colab 环境里面,我们的代码将通过 Eager 模式执行。


In [ ]:
tf.compat.v1.enable_eager_execution()

列出可用的数据集

每一个数据集(dataset)都实现了抽象基类 tfds.core.DatasetBuilder 来构建。你可以通过 tfds.list_builders() 列出所有可用的数据集。

你可以通过 数据集文档页 查看所有支持的数据集及其补充文档。


In [ ]:
tfds.list_builders()

tfds.load: 一行代码获取数据集

tfds.load 是构建并加载 tf.data.Dataset 最简单的方式。

tf.data.Dataset 是构建输入流水线的标准 TensorFlow 接口。如果你对这个接口不熟悉,我们强烈建议你阅读 TensorFlow 官方指南

下面,我们先加载 MNIST 训练数据。这个步骤会下载并准备好该数据,除非你显式指定 download=False。值得注意的是,一旦该数据准备好了,后续的 load 命令便不会重新下载,可以重复使用准备好的数据。 你可以通过指定 data_dir= (默认是 ~/tensorflow_datasets/) 来自定义数据保存/加载的路径。


In [ ]:
mnist_train = tfds.load(name="mnist", split="train")
assert isinstance(mnist_train, tf.data.Dataset)
print(mnist_train)

加载数据集时,将使用规范的默认版本。 但是,建议你指定要使用的数据集的主版本,并在结果中表明使用了哪个版本的数据集。更多详情请参考数据集版本控制文档


In [ ]:
mnist = tfds.load("mnist:1.*.*")

特征字典

所有 tfds 数据集都包含将特征名称映射到 Tensor 值的特征字典。 典型的数据集(如 MNIST)将具有2个键:"image""label"。 下面我们看一个例子。

注意:在图模式(graph mode)下,请参阅 tf.data 指南 以了解如何在 tf.data.Dataset 上进行迭代。


In [ ]:
for mnist_example in mnist_train.take(1):  # 只取一个样本
    image, label = mnist_example["image"], mnist_example["label"]

    plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
    print("Label: %d" % label.numpy())

DatasetBuilder

tfds.load 实际上是一个基于 DatasetBuilder 的简单方便的包装器。我们可以直接使用 MNIST DatasetBuilder 实现与上述相同的操作。


In [ ]:
mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()
mnist_train = mnist_builder.as_dataset(split="train")
mnist_train

输入流水线

一旦有了 tf.data.Dataset 对象,就可以使用 tf.data 接口 定义适合模型训练的输入流水线的其余部分。

在这里,我们将重复使用数据集,以便有无限的样本流,然后随机打乱并创建大小为32的批。


In [ ]:
mnist_train = mnist_train.repeat().shuffle(1024).batch(32)

# prefetch 将使输入流水线可以在模型训练时异步获取批处理。
mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)

# 现在你可以遍历数据集的批次并在 mnist_train 中训练批次:
#   ...

数据集信息(DatasetInfo)

生成后,构建器将包含有关数据集的有用信息:


In [ ]:
info = mnist_builder.info
print(info)

DatasetInfo 还包含特征相关的有用信息:


In [ ]:
print(info.features)
print(info.features["label"].num_classes)
print(info.features["label"].names)

你也可以通过 tfds.load 使用 with_info = True 加载 DatasetInfo


In [ ]:
mnist_test, info = tfds.load("mnist", split="test", with_info=True)
print(info)

可视化

对于图像分类数据集,你可以使用 tfds.show_examples 来可视化一些样本。


In [ ]:
fig = tfds.show_examples(info, mnist_test)