Eager Execution Tutorial: Importing Data

This notebook demonstrates how to use the tf.data.Dataset API to build pipelines that feed data to your program. It covers:

  • Creating a Dataset.
  • Iteration over a Dataset with eager execution enabled.

We recommend using the Datasets API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.

If you're familiar with TensorFlow graphs, the API for constructing the Dataset object remains exactly the same when eager execution is enabled, but iterating over the elements of the dataset is simpler. You can use Python iteration over the tf.data.Dataset object directly and do not need to explicitly create a tf.data.Iterator object. As a result, the discussion of iterators in the Programmer's Guide does not apply when eager execution is enabled.

Setup: Enable eager execution

In [0]:
# Import TensorFlow.
import tensorflow as tf

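# Enable eager execution. Call this once at program startup, before
# creating any other TensorFlow ops (TF 1.x API; eager execution is
# the default in TF 2.x).
tf.enable_eager_execution()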

Step 1: Create a source Dataset

Create a source dataset using a factory function such as Dataset.from_tensors or Dataset.from_tensor_slices, or using a class that reads from files, such as TextLineDataset or TFRecordDataset. See the Programmer's Guide for more information.
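The two factory functions differ in how they split their input: Dataset.from_tensors wraps the whole input as a single element, while Dataset.from_tensor_slices slices the input along its first dimension and yields one element per slice. So `from_tensor_slices([1, 2, 3, 4, 5, 6])` yields six scalar elements. The plain-Python sketch below models only that difference using nested lists; the helpers `from_tensors_sketch` and `from_tensor_slices_sketch` are hypothetical illustrations, not part of tf.data.

```python
def from_tensors_sketch(value):
    """Hypothetical model of Dataset.from_tensors: the whole input is one element."""
    return [value]


def from_tensor_slices_sketch(value):
    """Hypothetical model of Dataset.from_tensor_slices: one element per
    entry along the first dimension of the input."""
    return list(value)


data = [[1, 2], [3, 4], [5, 6]]  # a 3 x 2 "tensor" written as nested lists

print(len(from_tensors_sketch(data)))        # 1 element: [[1, 2], [3, 4], [5, 6]]
print(len(from_tensor_slices_sketch(data)))  # 3 elements: [1, 2], [3, 4], [5, 6]
```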

In [0]:
ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

# Create a small text file with one record per line.
import tempfile
_, filename = tempfile.mkstemp()
with open(filename, 'w') as f:
  f.write("""Line 1
Line 2
Line 3
  """)

ds_file = tf.data.TextLineDataset(filename)

Step 2: Apply transformations

Use transformation functions such as map, batch, and shuffle to transform the records of the dataset. See the API documentation for tf.data.Dataset for details.
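To see what shuffle and batch do to a stream of records, here is a simplified plain-Python model. It assumes the behaviour described in the tf.data.Dataset.shuffle documentation: fill a buffer with buffer_size elements, sample randomly from it, and replace each sampled element with the next input element. The helpers `shuffle_sketch` and `batch_sketch` are hypothetical and are not TensorFlow's implementation; the `drop_remainder` flag mirrors the option of the same name on Dataset.batch.

```python
import random


def shuffle_sketch(elements, buffer_size, seed=None):
    """Simplified buffer-based shuffle: fill a buffer with up to buffer_size
    elements, emit a randomly chosen one, and refill its slot with the next
    input element. Assumes buffer_size >= 1."""
    rng = random.Random(seed)
    source = iter(elements)
    end = object()  # sentinel marking exhaustion of the input
    buffer = []
    # Fill the buffer with the first buffer_size input elements.
    while len(buffer) < buffer_size:
        x = next(source, end)
        if x is end:
            break
        buffer.append(x)
    # Sample from the buffer until both the buffer and the input are exhausted.
    while buffer:
        i = rng.randrange(len(buffer))
        yield buffer[i]
        x = next(source, end)
        if x is end:
            buffer.pop(i)  # input exhausted: shrink the buffer
        else:
            buffer[i] = x  # replace the emitted element with the next input


def batch_sketch(elements, batch_size, drop_remainder=False):
    """Group consecutive elements into lists of batch_size; a final, shorter
    batch is kept unless drop_remainder is True."""
    batch = []
    for x in elements:
        batch.append(x)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_remainder:
        yield batch


# Rough plain-Python analogue of ds_tensors.map(tf.square).shuffle(2).batch(2)
squared = (x * x for x in [1, 2, 3, 4, 5, 6])
print(list(batch_sketch(shuffle_sketch(squared, buffer_size=2, seed=0), 2)))
```

With buffer_size=2, the first value emitted can only be 1 or 4, so a small buffer shuffles only locally. The tf.data.Dataset.shuffle documentation notes that a buffer at least as large as the dataset is needed for a perfect shuffle.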

In [0]:
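# Square each element, shuffle using a buffer of 2 elements, then group into batches of 2.
ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)
# Group the lines of the file into batches of 2.
ds_file = ds_file.batch(2)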

Step 3: Iterate

When eager execution is enabled, Dataset objects support iteration. If you're familiar with the use of Datasets in TensorFlow graphs, note that there is no need to call Dataset.make_one_shot_iterator() or get_next().
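The graph-mode pattern (create an iterator, then call get_next() until the data runs out) is the same control flow that a Python for loop performs through the iterator protocol. The plain-Python sketch below shows both forms side by side. `ToyDataset` is a hypothetical stand-in, not a TensorFlow class, and the analogy covers only the looping, not how eager tensors are produced.

```python
class ToyDataset:
    """Hypothetical stand-in for an eager Dataset: it is iterable, so a
    for loop can consume it directly."""

    def __init__(self, elements):
        self._elements = list(elements)

    def __iter__(self):
        return iter(self._elements)


ds = ToyDataset([1, 4, 9])

# Explicit form, analogous to make_one_shot_iterator() plus repeated get_next():
explicit = []
it = iter(ds)
while True:
    try:
        explicit.append(next(it))
    except StopIteration:  # graph mode signals the end with tf.errors.OutOfRangeError
        break

# Implicit form, as used with eager execution:
implicit = [x for x in ds]

print(explicit == implicit)  # True
```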

In [0]:
print('Elements of ds_tensors:')
for x in ds_tensors:
  print(x)

print('\nElements in ds_file:')
for x in ds_file:
  print(x)

Elements of ds_tensors:
tf.Tensor([1 9], shape=(2,), dtype=int32)
tf.Tensor([16 25], shape=(2,), dtype=int32)
tf.Tensor([ 4 36], shape=(2,), dtype=int32)

Elements in ds_file:
tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)
tf.Tensor(['Line 3' '  '], shape=(2,), dtype=string)
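The fourth element of ds_file, '  ', is the file's last line: it holds two spaces and has no trailing newline. TextLineDataset yields one string per line with the newline removed, including a final line that is not newline-terminated. The sketch below models that splitting in plain Python. `text_lines_sketch` is a hypothetical helper, and it assumes Unix '\n' line endings; the real reader's handling of other terminators is not modelled.

```python
def text_lines_sketch(content):
    """Hypothetical model of TextLineDataset's output: one string per line,
    with the newline terminator removed."""
    lines = content.split("\n")
    # A trailing newline ends the last line; it does not start a new, empty one.
    if lines and lines[-1] == "":
        lines.pop()
    return lines


content = "Line 1\nLine 2\nLine 3\n  "  # same text the tutorial writes to the file
print(text_lines_sketch(content))  # ['Line 1', 'Line 2', 'Line 3', '  ']
```

Batched in pairs, these four lines give the two elements printed above.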