In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TensorFlow Datasets 提供了一系列可以和 TensorFlow 配合使用的数据集。它负责下载和准备数据,以及构建 tf.data.Dataset
。
Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为, 所以无法保证它们是最准确的,并且反映了最新的 官方英文文档。如果您有改进此翻译的建议, 请提交 pull request 到 tensorflow/docs GitHub 仓库。要志愿地撰写或者审核译文,请加入 docs-zh-cn@tensorflow.org Google Group。
In [ ]:
!pip install -q tensorflow tensorflow-datasets matplotlib
In [ ]:
import tensorflow as tf
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
TensorFlow 数据集同时兼容 TensorFlow Eager 模式 和图模式。在这个 colab 环境里面,我们的代码将通过 Eager 模式执行。
In [ ]:
tf.compat.v1.enable_eager_execution()
每一个数据集(dataset)都实现了抽象基类 tfds.core.DatasetBuilder
来构建。你可以通过 tfds.list_builders()
列出所有可用的数据集。
你可以通过 数据集文档页 查看所有支持的数据集及其补充文档。
In [ ]:
tfds.list_builders()
tfds.load
: 一行代码获取数据集tfds.load
是构建并加载 tf.data.Dataset
最简单的方式。
tf.data.Dataset
是构建输入流水线的标准 TensorFlow 接口。如果你对这个接口不熟悉,我们强烈建议你阅读 TensorFlow 官方指南。
下面,我们先加载 MNIST 训练数据。这个步骤会下载并准备好该数据,除非你显式指定 download=False
。值得注意的是,一旦该数据准备好了,后续的 load
命令便不会重新下载,可以重复使用准备好的数据。
你可以通过指定 data_dir=
(默认是 ~/tensorflow_datasets/
) 来自定义数据保存/加载的路径。
In [ ]:
mnist_train = tfds.load(name="mnist", split="train")
assert isinstance(mnist_train, tf.data.Dataset)
print(mnist_train)
加载数据集时,将使用规范的默认版本。 但是,建议你指定要使用的数据集的主版本,并在结果中表明使用了哪个版本的数据集。更多详情请参考数据集版本控制文档。
In [ ]:
mnist = tfds.load("mnist:1.*.*")
所有 tfds
数据集都包含将特征名称映射到 Tensor 值的特征字典。 典型的数据集(如 MNIST)将具有2个键:"image"
和 "label"
。 下面我们看一个例子。
注意:在图模式(graph mode)下,请参阅 tf.data 指南 以了解如何在 tf.data.Dataset
上进行迭代。
In [ ]:
for mnist_example in mnist_train.take(1): # 只取一个样本
image, label = mnist_example["image"], mnist_example["label"]
plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())
In [ ]:
mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()
mnist_train = mnist_builder.as_dataset(split="train")
mnist_train
一旦有了 tf.data.Dataset
对象,就可以使用 tf.data
接口 定义适合模型训练的输入流水线的其余部分。
在这里,我们将重复使用数据集,以便有无限的样本流,然后随机打乱并创建大小为32的批。
In [ ]:
mnist_train = mnist_train.repeat().shuffle(1024).batch(32)
# prefetch 将使输入流水线可以在模型训练时异步获取批处理。
mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)
# 现在你可以遍历数据集的批次并在 mnist_train 中训练批次:
# ...
In [ ]:
info = mnist_builder.info
print(info)
DatasetInfo
还包含特征相关的有用信息:
In [ ]:
print(info.features)
print(info.features["label"].num_classes)
print(info.features["label"].names)
你也可以通过 tfds.load
使用 with_info = True
加载 DatasetInfo
。
In [ ]:
mnist_test, info = tfds.load("mnist", split="test", with_info=True)
print(info)
In [ ]:
fig = tfds.show_examples(info, mnist_test)