Eager Execution Tutorial: Importing Data

This notebook demonstrates how to use the tf.contrib.data.Dataset API to build pipelines that feed data to your program. It covers:

  • Creating a Dataset.
  • Iteration over a Dataset with eager execution enabled.

We recommend using the Datasets API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.
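As a rough analogy in plain Python (this is an illustration of the pipeline idea, not the Datasets API itself), composing simple, re-usable pieces into an input pipeline might look like:

```python
# A plain-Python analogy (not the Datasets API): an input pipeline built
# by chaining small, re-usable transformations over a stream of records.

def square(stream):
    # Transform each record (analogous to Dataset.map).
    for x in stream:
        yield x * x

def batch(stream, size):
    # Group consecutive records into fixed-size batches (analogous to Dataset.batch).
    buf = []
    for x in stream:
        buf.append(x)
        if len(buf) == size:
            yield buf
            buf = []
    if buf:
        yield buf  # final partial batch

pipeline = batch(square(range(1, 7)), 2)
print(list(pipeline))  # [[1, 4], [9, 16], [25, 36]]
```

Each stage consumes the previous one lazily, which is the same compositional style the Datasets API encourages.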

If you're familiar with TensorFlow graphs, the API for constructing a Dataset object is exactly the same when eager execution is enabled, but iterating over the dataset's elements is slightly different: you use the Pythonic tfe.Iterator class instead of make_one_shot_iterator() and get_next(). As a result, the discussion of iterators in the Programmer's Guide does not apply when eager execution is enabled.

Setup: Enable eager execution


In [0]:
# Import TensorFlow.
import tensorflow as tf

# Import TensorFlow eager execution support (subject to future changes).
import tensorflow.contrib.eager as tfe

# Enable eager execution
tfe.enable_eager_execution()

Step 1: Create a source Dataset

Create a source dataset using one of the factory functions like Dataset.from_tensors and Dataset.from_tensor_slices, or using objects that read from files, such as TextLineDataset or TFRecordDataset. See the Programmer's Guide for more information.


In [0]:
ds_tensors = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

# Create a text file with a few lines to read back in.
import tempfile
_, filename = tempfile.mkstemp()
with open(filename, 'w') as f:
  f.write("""Line 1
Line 2
Line 3
  """)
ds_file = tf.contrib.data.TextLineDataset(filename)

Step 2: Apply transformations

Use transformation functions like map, batch, and shuffle to apply transformations to the records of the dataset. See the API documentation for tf.contrib.data.Dataset for details.


In [0]:
ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)
ds_file = ds_file.batch(2)
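To build intuition for what shuffle(buffer_size) does, here is a plain-Python sketch of a buffered shuffle (an illustrative assumption about the buffered sampling idea, not the library's actual implementation): it keeps at most buffer_size records in memory and emits a randomly chosen one as each new record arrives, so a small buffer gives only a partial shuffle.

```python
import random

def buffered_shuffle(stream, buffer_size, seed=None):
    # Sketch of a buffered shuffle: hold up to `buffer_size` records
    # and emit a randomly chosen one as each new record arrives.
    rng = random.Random(seed)
    buf = []
    for x in stream:
        buf.append(x)
        if len(buf) == buffer_size:
            yield buf.pop(rng.randrange(len(buf)))
    # Drain whatever is left in the buffer, in random order.
    rng.shuffle(buf)
    for x in buf:
        yield x

shuffled = list(buffered_shuffle([1, 2, 3, 4, 5, 6], buffer_size=2, seed=0))
print(shuffled)  # some permutation of [1, 2, 3, 4, 5, 6]
```

With buffer_size=2 an element can move only a short distance from its original position, which is why the batched output above is only lightly reordered.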

Step 3: Iterate

Use tfe.Iterator on the Dataset object to get a Python iterator over the contents of the dataset.

If you're familiar with the use of Datasets in TensorFlow graphs, note that this process of iteration is different: there are no calls to Dataset.make_one_shot_iterator() and no get_next() calls.


In [5]:
print('Elements of ds_tensors:')
for x in tfe.Iterator(ds_tensors):
  print(x)

print('\nElements in ds_file:')
for x in tfe.Iterator(ds_file):
  print(x)


Elements of ds_tensors:
tf.Tensor([4 9], shape=(2,), dtype=int32)
tf.Tensor([16 25], shape=(2,), dtype=int32)
tf.Tensor([36  1], shape=(2,), dtype=int32)

Elements in ds_file:
tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)
tf.Tensor(['Line 3' '  '], shape=(2,), dtype=string)