Dataset API examples

Starting to hack on these. Not ready yet :)


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
print('This code requires TensorFlow v1.4+')
print('You have:', tf.__version__)


This code requires TensorFlow v1.4+
You have: 1.4.0-dev20170921

In [42]:
def generator():
    # Your regular Python input processing code goes here
    for i in range(10):
        yield ("foo", 1.0)

def input_fn():
    dataset = (
        tf.contrib.data.Dataset
        # In TensorFlow 1.4, create a Dataset from a Python generator.
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/data/python/ops/dataset_ops.py
        # Just list the datatypes that are returned.
        .from_generator(generator, (tf.string, tf.float64))
        # Randomly shuffle using a buffer of 10000 examples.
        .shuffle(10000)
        # Repeat for 10 epochs.
        .repeat(10)
        # Combine 32 consecutive elements into a batch.
        .batch(32)
        # Use prefetch() to overlap the producer and consumer
        # for a little performance optimization.
        .prefetch(1)
    )
    return dataset.make_one_shot_iterator().get_next()

# Now we have an input function. We could pass input_fn
# itself to a TensorFlow Estimator for training
# (see the sketch at the end of this section).

# Or, we can demo it directly, like this: calling it once
# gives us the next-batch tensors to run in a Session.
# Pretty cool.
next_batch = input_fn()

with tf.Session() as sess:
    stuff = sess.run(next_batch)
    print(stuff)


(array(['foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo',
       'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo',
       'foo', 'foo'], dtype=object), array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.]))
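
As mentioned above, here is a minimal sketch (not part of the original notebook) of how input_fn could be handed to a TensorFlow Estimator. The model_fn below is hypothetical: it just fits a scalar bias to the constant labels and only handles training, to show where the input function plugs in.

def model_fn(features, labels, mode):
    # features is the batch of strings and labels the batch of floats,
    # exactly as returned by input_fn above.
    # This toy model fits a single scalar to the labels.
    bias = tf.get_variable('bias', shape=[], dtype=tf.float64)
    loss = tf.reduce_mean(tf.square(labels - bias))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

# The Estimator calls input_fn itself, so we pass the function,
# not its return value. Training stops when the dataset is
# exhausted (10 epochs of the toy generator).
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)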