# In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
# Report the installed TensorFlow version up front. The stated minimum is
# v1.4; it is only printed here, not enforced.
print('\n'.join([
    'This code requires TensorFlow v1.4+',
    ' '.join(['You have:', tf.__version__]),
]))
# In [42]:
def generator():
    """Yield the raw examples consumed by ``input_fn``.

    This placeholder yields the same ``("foo", 1.0)`` pair ten times; replace
    the body with real input processing. Each yielded tuple must match the
    ``(tf.string, tf.float64)`` output types declared in ``input_fn``.
    """
    # Your regular Python input processing code goes here
    for _ in range(10):
        yield ("foo", 1.0)
def input_fn():
    """Build the input pipeline and return its next-batch tensors.

    Wraps ``generator`` in a ``tf.data.Dataset``, shuffles it, repeats it for
    10 epochs, groups it into batches of 32 and prefetches one batch ahead.

    Returns:
        The ``(strings, floats)`` tuple of tensors from a one-shot iterator's
        ``get_next()``. Each ``sess.run`` on it yields the next batch; since
        ``batch`` keeps partial batches, the last one may be smaller than 32.
    """
    dataset = (
        # tf.data is part of the core API from TensorFlow 1.4 onwards;
        # tf.contrib.data.Dataset is its deprecated predecessor.
        tf.data.Dataset
        # Create a Dataset from a Python generator. Only the datatypes of the
        # yielded tuple need to be listed.
        .from_generator(generator, (tf.string, tf.float64))
        # Randomly shuffle using a buffer of 10000 examples.
        .shuffle(10000)
        # Repeat for 10 epochs.
        .repeat(10)
        # Combine 32 consecutive elements into a batch.
        .batch(32)
        # Use prefetch() to overlap the producer and consumer
        # for a little performance optimization.
        .prefetch(1)
    )
    return dataset.make_one_shot_iterator().get_next()
# Calling input_fn() builds the pipeline and returns its next-batch tensors,
# so ``in_fn`` holds tensors, not a function. To train a TensorFlow
# Estimator, pass the function object itself (``input_fn``) instead.
in_fn = input_fn()

# Demo: pull a single batch through a session and show it.
with tf.Session() as sess:
    stuff = sess.run(in_fn)
print(stuff)