Example of TFRecords creation for a sequence


In [11]:
import os

import numpy as np
import tensorflow as tf
from sklearn.utils import shuffle

%matplotlib inline
import matplotlib.pyplot as plt

In [6]:
# Load MNIST data into numpy arrays
(X_trn, y_trn), (X_tst, y_tst) = tf.keras.datasets.mnist.load_data()

X_trn = np.reshape(X_trn, [X_trn.shape[0], 28, 28, 1])
X_tst = np.reshape(X_tst, [X_tst.shape[0], 28, 28, 1])
print(X_trn.shape)
print(y_trn.shape)


(60000, 28, 28, 1)
(60000,)

In [22]:
# Functions

def sequence_generator(X, y, batch_size=32, seq_size=3):
    '''Yields sequences of `seq_size` consecutive digits.

    Each item is a single image of shape (28, 28*seq_size, 1), built by
    concatenating `seq_size` digits side by side, plus the list of their
    labels. `batch_size` is only used to keep the random start index in
    range; one sequence is yielded per call.
    '''
    X, y = shuffle(X, y)
    while True:
        # Pick a random start index and take seq_size consecutive samples
        start = np.random.randint(len(X) - (seq_size * batch_size))
        seq_x = [X[start + j] for j in range(seq_size)]
        seq_y = [y[start + j] for j in range(seq_size)]
        # Concatenate the digits horizontally into one wide image
        image = np.concatenate(seq_x, axis=1)
        yield image, seq_y

s = sequence_generator(X_trn, y_trn, batch_size=32, seq_size=3)
img, l = next(s)

plt.imshow(img[:,:,0], cmap='gray')
print(type(l))


<class 'list'>
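
A quick check of what the generator yields: one concatenated image and a list of `seq_size` labels (a small verification sketch, reusing the `s` generator created above).


In [ ]:
# Each yielded item is one (28, 28*seq_size, 1) image plus seq_size labels
img, l = next(s)
print(img.shape)  # expected: (28, 84, 1)
print(len(l), l)  # expected: 3 labels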

In [23]:
def _int64_feature(values):
    """Wraps an int or a list of ints into a tf.train.Feature."""
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _bytes_feature(values):
    """Wraps a bytes value into a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
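
As a minimal illustration of what these helpers produce (a toy round trip, not part of the pipeline; the values below are made up), an Example can be serialized and parsed back in pure Python.


In [ ]:
# Toy round trip: build one Example with the helpers, serialize, parse back
toy = tf.train.Example(features=tf.train.Features(feature={
    'image': _bytes_feature(np.zeros((28, 84, 1), np.uint8).tobytes()),
    'label': _int64_feature([1, 2, 3]),
}))
parsed = tf.train.Example()
parsed.ParseFromString(toy.SerializeToString())
print(list(parsed.features.feature['label'].int64_list.value))  # [1, 2, 3]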

In [24]:
def create_sequence_tfrecord(images, labels, size, output_file):
    """Converts a file to TFRecords."""
    print('Generating %s' % output_file)
    with tf.python_io.TFRecordWriter(output_file) as record_writer:
        
        s = sequence_generator(images, labels, batch_size=32, seq_size=3)
        for i in range(size):
            image, label = next(s)
            example = tf.train.Example(features=tf.train.Features(
                feature={
                        'image': _bytes_feature(image.tobytes()),
                        'label': _int64_feature(label)
                        }))
            record_writer.write(example.SerializeToString())
    print('Done!')

In [25]:
trn_tfrecords_file = '/tmp/trn.tfrecord'
create_sequence_tfrecord(X_trn, y_trn, 500, trn_tfrecords_file)

tst_tfrecords_file = '/tmp/tst.tfrecord'
create_sequence_tfrecord(X_tst, y_tst, 100, tst_tfrecords_file)


Generating /tmp/trn.tfrecord
Done!
Generating /tmp/tst.tfrecord
Done!
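
To confirm that the expected number of sequences was written (500 train, 100 test), the records can be counted with the TF 1.x record iterator; a small check assuming the two files created above.


In [ ]:
# Count the serialized examples in each TFRecord file
for path in [trn_tfrecords_file, tst_tfrecords_file]:
    n = sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print(path, n)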

Create the parser and the input_fn functions


In [31]:
DEPTH = 1
HEIGHT = 28
WIDTH = 28*3

def mnist_parser(serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([3], tf.int64),
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    image.set_shape([DEPTH * HEIGHT * WIDTH])

    # Reshape to [depth, height, width], then transpose to [height, width, depth].
    image = tf.cast(
        tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
        tf.float32)
    label = tf.cast(features['label'], tf.int32)

    # Custom preprocessing could be applied here, e.g.:
    # image = preprocess(image)

    return image, label
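
Before wiring the parser into a dataset, it can be exercised on a single raw record; a small sketch that reads one serialized example from the training file and runs the parser in a throwaway session.


In [ ]:
# Parse one serialized record directly to verify shapes and dtypes
raw = next(tf.python_io.tf_record_iterator(trn_tfrecords_file))
img_t, lbl_t = mnist_parser(tf.constant(raw))
with tf.Session() as sess:
    img_v, lbl_v = sess.run([img_t, lbl_t])
print(img_v.shape, img_v.dtype, lbl_v)  # expected: (28, 84, 1) float32 and 3 labels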

In [32]:
def train_input_fn(TFfilenames, batch_size):
    """An input function for training"""
    
    dataset = tf.data.TFRecordDataset(TFfilenames)
    dataset = dataset.map(mnist_parser, num_parallel_calls=1)
    
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.cache().shuffle(buffer_size=1000).repeat().batch(batch_size)

    # Build a one-shot iterator and return its next elements.
    # In TF 1.6 and above you can return the dataset directly and the
    # Estimator builds the iterator internally.
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    return (images, labels)
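
As the comment above notes, from TF 1.6 onwards an input_fn may also return the tf.data.Dataset itself and let the Estimator build the iterator. A sketch of that variant (train_input_fn_v2 is a name introduced here, not used elsewhere in the notebook).


In [ ]:
def train_input_fn_v2(TFfilenames, batch_size):
    """Same pipeline, but returns the dataset itself (TF >= 1.6 Estimators)."""
    dataset = tf.data.TFRecordDataset(TFfilenames)
    dataset = dataset.map(mnist_parser, num_parallel_calls=1)
    dataset = dataset.cache().shuffle(buffer_size=1000).repeat().batch(batch_size)
    return dataset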

In [33]:
def test_input_fn(TFfilenames, batch_size):
    """An input function for evaluation: same parsing, no shuffle or repeat."""
    dataset = tf.data.TFRecordDataset(TFfilenames)
    dataset = dataset.map(mnist_parser, num_parallel_calls=1).batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

In [34]:
# Define our input pipeline. Pin it to the CPU so that the GPU can be reserved
# for forward and backward propagation.

tf.reset_default_graph()

batch_size = 32
with tf.device('/cpu:0'):
    train_images, train_labels = train_input_fn(trn_tfrecords_file, batch_size)

Check the TFRecord content


In [35]:
# Sanity check that all is correct


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # tf.data iterators do not use queue runners, so no extra setup is needed
    sample_images, sample_labels = sess.run([train_images, train_labels])

plt.imshow(sample_images[0,:,:,0], cmap='gray')
print(sample_labels)


[[8 2 2]
 [2 8 0]
 [7 0 9]
 [1 8 5]
 [1 7 8]
 [0 5 2]
 [2 2 0]
 [1 0 4]
 [7 3 2]
 [2 4 5]
 [3 8 3]
 [4 6 1]
 [7 9 8]
 [4 9 2]
 [1 6 6]
 [0 6 8]
 [2 0 5]
 [7 0 8]
 [8 1 6]
 [7 5 2]
 [6 0 4]
 [8 0 0]
 [1 0 6]
 [8 0 8]
 [8 5 4]
 [7 1 2]
 [0 3 3]
 [0 4 9]
 [7 1 8]
 [0 2 5]
 [1 3 3]
 [0 9 1]]
