This notebook demonstrates how to use the tf.data.Dataset API to build pipelines that feed data to your program. It covers creating a Dataset and iterating over a Dataset with eager execution enabled.

We recommend the Dataset API for building performant, complex input pipelines from simple, reusable pieces that feed your model's training or evaluation loops.
If you're familiar with TensorFlow graphs, the API for constructing the Dataset object stays exactly the same when eager execution is enabled. The way you iterate over the elements of the dataset is slightly different. You use a Pythonic tfe.Iterator class instead of make_one_shot_iterator() and get_next(). As a result, the discussion of iterators in the Programmer's Guide does not apply when eager execution is enabled.
In [0]:
# Import TensorFlow.
import tensorflow as tf
# Import TensorFlow eager execution support (subject to future changes).
import tensorflow.contrib.eager as tfe
# Enable eager execution
tfe.enable_eager_execution()
Create a source Dataset
Create a source dataset using one of the factory functions, such as Dataset.from_tensors or Dataset.from_tensor_slices, or using objects that read from files, such as TextLineDataset or TFRecordDataset. See the Programmer's Guide for more information.
In [0]:
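import os
import tempfile

# Source 1: one element per entry of an in-memory list.
ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

# Create a small text file with three lines (closing the descriptor from mkstemp).
fd, filename = tempfile.mkstemp()
with os.fdopen(fd, 'w') as f:
  f.write("""Line 1
Line 2
Line 3
""")

# Source 2: one string element per line of the file.
ds_file = tf.data.TextLineDataset(filename)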
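The sources differ in how they divide their input into elements. Dataset.from_tensors treats the whole input as a single element. Dataset.from_tensor_slices produces one element per entry along the first dimension. TextLineDataset produces one string per line.

The small model below illustrates that difference in plain Python. It does not use TensorFlow. The function names are hypothetical, and lists stand in for datasets. It is a sketch of the resulting element structure only, not of how TensorFlow implements these sources.

```python
# Hypothetical pure-Python model (not TensorFlow) of how three tf.data
# sources split their input into elements. Lists stand in for datasets.


def from_tensors_model(value):
  """Models Dataset.from_tensors: the whole input becomes one element."""
  return [value]


def from_tensor_slices_model(value):
  """Models Dataset.from_tensor_slices: one element per entry along the
  first dimension of the input."""
  return [item for item in value]


def text_line_model(text):
  """Approximately models TextLineDataset: one string element per line,
  without the trailing newline."""
  return text.splitlines()


data = [1, 2, 3, 4, 5, 6]
print(from_tensors_model(data))        # [[1, 2, 3, 4, 5, 6]]
print(from_tensor_slices_model(data))  # [1, 2, 3, 4, 5, 6]
print(text_line_model("Line 1\nLine 2\nLine 3\n"))
# ['Line 1', 'Line 2', 'Line 3']
```

So ds_tensors above has six scalar elements. Building it with from_tensors instead would give one element holding the whole list.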
Use transformation functions such as map, batch, and shuffle to transform the records of the dataset. See the API documentation for tf.data.Dataset for details.
In [0]:
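# Square each element, shuffle using a buffer of 2 elements, then group into batches of 2.
ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)
# Group the lines of the file into batches of 2 (the last batch holds the remaining line).
ds_file = ds_file.batch(2)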
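You can model the same pipeline in plain Python to see what each transformation does to the sequence of elements. This is a hypothetical sketch, not TensorFlow code. The names map_model, shuffle_model and batch_model are illustrative.

The sketch assumes the buffer-based behavior of Dataset.shuffle:

1. Fill a buffer with up to buffer_size elements.
2. Repeatedly emit a uniformly random element from the buffer.
3. Refill the buffer from the input after each emission.

```python
import random


def map_model(elements, fn):
  """Hypothetical model of Dataset.map: apply fn to every element."""
  return [fn(x) for x in elements]


def shuffle_model(elements, buffer_size, rng):
  """Hypothetical model of buffer-based shuffling (assumes buffer_size >= 1).

  Keeps at most buffer_size elements in a buffer; once the buffer is full,
  emits a uniformly random element from it before reading more input.
  """
  buffer = []
  output = []
  for x in elements:
    buffer.append(x)
    if len(buffer) == buffer_size:
      output.append(buffer.pop(rng.randrange(len(buffer))))
  # Input exhausted: drain what is left in random order.
  while buffer:
    output.append(buffer.pop(rng.randrange(len(buffer))))
  return output


def batch_model(elements, batch_size):
  """Hypothetical model of Dataset.batch: group consecutive elements; the
  last batch is smaller when the count is not a multiple of batch_size."""
  return [elements[i:i + batch_size]
          for i in range(0, len(elements), batch_size)]


rng = random.Random(0)
squares = map_model([1, 2, 3, 4, 5, 6], lambda x: x * x)
print(batch_model(shuffle_model(squares, 2, rng), 2))
print(batch_model(['Line 1', 'Line 2', 'Line 3'], 2))
# [['Line 1', 'Line 2'], ['Line 3']]
```

With a buffer size of 2, each output element can only come from the next two unread inputs, so the shuffle is local. A buffer at least as large as the dataset holds every element before the first one is emitted, which gives a full uniform shuffle.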
Use tfe.Iterator on the Dataset object to get a Python iterator over the contents of the dataset.
If you're familiar with using Datasets in TensorFlow graphs, note that iteration works differently here. There are no calls to Dataset.make_one_shot_iterator() and no get_next() calls.
In [5]:
print('Elements of ds_tensors:')
for x in tfe.Iterator(ds_tensors):
  print(x)

print('\nElements in ds_file:')
for x in tfe.Iterator(ds_file):
  print(x)
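In graph mode, you call get_next() repeatedly and stop when TensorFlow raises tf.errors.OutOfRangeError. A Pythonic iterator lets an ordinary for loop do that work instead. One way to picture this is a wrapper that adapts a get_next()-style source to Python's iterator protocol, turning the end-of-data error into StopIteration.

This is a hypothetical sketch. GetNextSource, OutOfRangeError and PythonicIterator are illustrative stand-ins, not TensorFlow classes, and the sketch does not describe how tfe.Iterator is actually implemented.

```python
class OutOfRangeError(Exception):
  """Hypothetical stand-in for the end-of-data error raised by graph-mode
  get_next() (tf.errors.OutOfRangeError in TensorFlow)."""


class GetNextSource(object):
  """Hypothetical graph-style source: the caller must call get_next() and
  catch OutOfRangeError to detect the end of the data."""

  def __init__(self, elements):
    self._elements = list(elements)
    self._pos = 0

  def get_next(self):
    if self._pos >= len(self._elements):
      raise OutOfRangeError()
    value = self._elements[self._pos]
    self._pos += 1
    return value


class PythonicIterator(object):
  """Hypothetical wrapper that exposes a get_next()-style source through the
  Python iterator protocol, so a plain for loop can consume it."""

  def __init__(self, source):
    self._source = source

  def __iter__(self):
    return self

  def __next__(self):
    try:
      return self._source.get_next()
    except OutOfRangeError:
      # Translate the end-of-data error into the signal for loops expect.
      raise StopIteration

  next = __next__  # Python 2 name for the same method.


for x in PythonicIterator(GetNextSource([[1, 4], [9, 16]])):
  print(x)
```

The loop body never sees the end-of-data error. That is why the eager-mode cell above needs no explicit get_next() calls or exception handling.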