Example of TFRecords creation


In [1]:
import os
import sys

import numpy as np

import tensorflow as tf

import argparse


os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [2]:
# Load MNIST data into numpy arrays
(X_trn, y_trn), (X_tst, y_tst) = tf.keras.datasets.mnist.load_data()

X_trn = np.reshape(X_trn, [X_trn.shape[0], 28, 28, 1])
X_tst = np.reshape(X_tst, [X_tst.shape[0], 28, 28, 1])
print(X_trn.shape)
print(y_trn.shape)


(60000, 28, 28, 1)
(60000,)

In [3]:
def _int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
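
To see what these helpers produce, here is a minimal sketch that builds and serializes a single tf.train.Example by hand (the zero image and label 7 are arbitrary illustration values):

example = tf.train.Example(features=tf.train.Features(feature={
    'image': _bytes_feature(np.zeros([28, 28, 1], dtype=np.uint8).tobytes()),
    'label': _int64_feature(7),
}))
print(len(example.SerializeToString()))  # size in bytes of one serialized record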

In [4]:
def convert_arrays_to_tfrecord(images, labels, output_file):
    """Converts numpy arrays of images and labels to a TFRecord file."""
    print('Generating %s' % output_file)
    with tf.python_io.TFRecordWriter(output_file) as record_writer:
        for image, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(
                feature={
                        'image': _bytes_feature(image.tobytes()),
                        'label': _int64_feature(int(label))
                        }))
            record_writer.write(example.SerializeToString())
    print('Done!')

In [5]:
trn_tfrecords_file = '/tmp/trn.tfrecord'
convert_arrays_to_tfrecord(X_trn, y_trn, trn_tfrecords_file)

tst_tfrecords_file = '/tmp/tst.tfrecord'
convert_arrays_to_tfrecord(X_tst, y_tst, tst_tfrecords_file)


Generating /tmp/trn.tfrecord
Done!
Generating /tmp/tst.tfrecord
Done!
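
Before wiring the files into an input pipeline, the record counts can be double-checked by iterating over the raw files. A minimal sketch using the TF 1.x tf.python_io.tf_record_iterator API:

n_trn = sum(1 for _ in tf.python_io.tf_record_iterator(trn_tfrecords_file))
n_tst = sum(1 for _ in tf.python_io.tf_record_iterator(tst_tfrecords_file))
print(n_trn, n_tst)  # expected: 60000 10000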

Create the parser and the input_fn functions


In [6]:
DEPTH = 1
HEIGHT = 28
WIDTH = 28

def mnist_parser(serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    image.set_shape([DEPTH * HEIGHT * WIDTH])

    # Reshape from [depth * height * width] to [depth, height, width].
    image = tf.cast(
        tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
        tf.float32)
    label = tf.cast(features['label'], tf.int32)

    # Custom preprocessing could be applied here (a sketch follows this cell):
    # image = preprocess(image)

    return image, label
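
A possible form of the preprocessing hook commented out above; preprocess is a hypothetical helper introduced here for illustration:

def preprocess(image):
    # Hypothetical: scale float32 pixel values from [0, 255] to [0, 1].
    return image / 255.0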

In [7]:
def train_input_fn(TFfilenames, batch_size):
    """An input function for training"""
    
    dataset = tf.data.TFRecordDataset(TFfilenames)
    dataset = dataset.map(mnist_parser, num_parallel_calls=1)
    
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.cache().shuffle(buffer_size=1000).repeat().batch(batch_size)

    # Generate an iterator and return its next elements.
    # In TF 1.6 and above you can return the dataset directly and the
    # Estimator builds the iterator internally (see the sketch after this cell).
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    return (images, labels)
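
The iterator boilerplate disappears in TF 1.6 and above: the input function may return the tf.data.Dataset itself, and the Estimator creates the iterator internally. A minimal sketch of that variant, assuming TF >= 1.6:

def train_input_fn_v16(TFfilenames, batch_size):
    """Same pipeline, returning the Dataset directly (TF >= 1.6)."""
    dataset = tf.data.TFRecordDataset(TFfilenames)
    dataset = dataset.map(mnist_parser, num_parallel_calls=1)
    return dataset.cache().shuffle(buffer_size=1000).repeat().batch(batch_size)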

In [8]:
def test_input_fn(TFfilenames, batch_size):
    """An input function for evaluation: one pass, no shuffling or repeating."""
    dataset = tf.data.TFRecordDataset(TFfilenames)
    dataset = dataset.map(mnist_parser, num_parallel_calls=1).batch(batch_size)
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    return (images, labels)

In [9]:
# Define our input pipeline. Pin it to the CPU so that the GPU can be reserved
# for forward and backward propagation.

tf.reset_default_graph()

batch_size = 32
with tf.device('/cpu:0'):
    train_images, train_labels = train_input_fn(trn_tfrecords_file, batch_size)

Check the TFRecord content


In [10]:
# Sanity check that all is correct
%matplotlib inline
import matplotlib.pyplot as plt

# tf.data iterators do not use queue runners, so a plain session is enough.
with tf.Session() as sess:
    sample_images, sample_labels = sess.run([train_images, train_labels])

plt.imshow(sample_images[0,:,:,0], cmap='gray')
print(sample_labels)


[2 2 5 9 5 4 6 1 1 2 0 6 3 7 0 8 4 4 1 2 3 8 8 2 1 4 0 2 8 3 4 6]

Use a custom Estimator


In [11]:
class Model(tf.keras.models.Model):
  """Model to recognize digits in the MNIST dataset.
  """

  def __init__(self):
    super(Model, self).__init__()

    # Define layers to use in the model
    self._input_shape = [-1, 28, 28, 1]

    self.conv1 = tf.layers.Conv2D(32, 5, padding='same', activation=tf.nn.relu)
    self.max_pool2d = tf.layers.MaxPooling2D((2, 2), (2, 2), padding='same')
    
    self.conv2 = tf.layers.Conv2D(64, 5, padding='same', activation=tf.nn.relu)
    
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.dropout = tf.layers.Dropout(0.4)
    
    self.fc2 = tf.layers.Dense(10)

    
  def __call__(self, inputs, training):
    """Add operations to classify a batch of input images.
    Args:
      inputs: A Tensor representing a batch of input images.
      training: A boolean. Set to True to add operations required only when
        training the classifier.
    Returns:
      A logits Tensor with shape [<batch_size>, 10].
    """
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    y = self.dropout(y, training=training)
    return self.fc2(y)
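
As a quick architecture check (a sketch, not required for training): running the model on a dummy batch should yield logits of shape [batch_size, 10].

model = Model()
dummy = tf.zeros([8, 28, 28, 1])
logits = model(dummy, training=False)
print(logits.shape)  # expected: (8, 10)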

In [12]:
# Define the model


# Define the model_function compatible with tf.estimators
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    image = features
    if isinstance(image, dict):
        image = features['image']
    
    # Instantiate the model
    model = Model()
    

    # Train step
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)

        logits = model(image, training=True)
    
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        accuracy = tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, axis=1))
        tf.identity(accuracy[1], name='train_accuracy')
    
        tf.summary.scalar('train_accuracy', accuracy[1])
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))



    if mode == tf.estimator.ModeKeys.EVAL:
        logits = model(image, training=False)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops={
                'accuracy':
                    tf.metrics.accuracy(
                        labels=labels,
                        predictions=tf.argmax(logits, axis=1)),
            })


    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

In [13]:
def main(unused_argv):
    
    # Create classifier
    mnist_classifier = tf.estimator.Estimator(
          model_fn=model_fn,
          model_dir='/tmp/mnist',
          params={})

    # Set up a training hook that logs the training accuracy every 10 steps.
    tensors_to_log = {'train_accuracy': 'train_accuracy'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=10)

    # Train the model
    mnist_classifier.train(input_fn=lambda:train_input_fn(trn_tfrecords_file, FLAGS.batch_size),
                           hooks=[logging_hook], max_steps=FLAGS.train_steps)

    
    # Evaluate the model and print results. Note: use the non-repeating
    # test_input_fn here; train_input_fn repeats forever and the evaluation
    # would never terminate.
    eval_results = mnist_classifier.evaluate(input_fn=lambda: test_input_fn(tst_tfrecords_file, FLAGS.batch_size))
    print()
    print('Evaluation results:\n\t%s' % eval_results)

    # Export the model
    image = tf.placeholder(tf.float32, [None, 28, 28, 1])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({'image': image})
    mnist_classifier.export_savedmodel('/tmp/mnist_model/', input_fn)
    
    

if __name__ == '__main__':
    
    # Define the arguments of the program
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--train_steps', default=1000, type=int,
                        help='number of training steps')

    tf.logging.set_verbosity(tf.logging.INFO)
    FLAGS, unparsed = parser.parse_known_args()
    
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_steps': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fc51817c400>, '_task_id': 0, '_save_summary_steps': 100, '_service': None, '_tf_random_seed': None, '_log_step_count_steps': 100, '_model_dir': '/tmp/mnist', '_is_chief': True, '_master': '', '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_keep_checkpoint_every_n_hours': 10000, '_session_config': None, '_save_checkpoints_secs': 600, '_keep_checkpoint_max': 5, '_task_type': 'worker'}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from /tmp/mnist/model.ckpt-100
INFO:tensorflow:Saving checkpoints for 101 into /tmp/mnist/model.ckpt.
INFO:tensorflow:train_accuracy = 0.84375
INFO:tensorflow:loss = 1.00334, step = 101
INFO:tensorflow:train_accuracy = 0.84375 (0.198 sec)
INFO:tensorflow:train_accuracy = 0.875 (0.059 sec)
INFO:tensorflow:train_accuracy = 0.851562 (0.061 sec)
INFO:tensorflow:train_accuracy = 0.83125 (0.057 sec)
INFO:tensorflow:train_accuracy = 0.838542 (0.056 sec)
INFO:tensorflow:train_accuracy = 0.834821 (0.058 sec)
INFO:tensorflow:train_accuracy = 0.832031 (0.058 sec)
INFO:tensorflow:train_accuracy = 0.840278 (0.062 sec)
INFO:tensorflow:train_accuracy = 0.85 (0.064 sec)
INFO:tensorflow:global_step/sec: 136.794
INFO:tensorflow:train_accuracy = 0.84375 (0.060 sec)
INFO:tensorflow:loss = 1.26865, step = 201 (0.732 sec)
INFO:tensorflow:train_accuracy = 0.851562 (0.057 sec)
INFO:tensorflow:train_accuracy = 0.848558 (0.055 sec)
INFO:tensorflow:train_accuracy = 0.852679 (0.059 sec)
INFO:tensorflow:train_accuracy = 0.8625 (0.065 sec)
INFO:tensorflow:train_accuracy = 0.857422 (0.065 sec)
INFO:tensorflow:train_accuracy = 0.862132 (0.063 sec)
INFO:tensorflow:train_accuracy = 0.864583 (0.063 sec)
INFO:tensorflow:train_accuracy = 0.866776 (0.064 sec)
INFO:tensorflow:train_accuracy = 0.867188 (0.067 sec)
INFO:tensorflow:global_step/sec: 159.047
INFO:tensorflow:train_accuracy = 0.870536 (0.071 sec)
INFO:tensorflow:loss = 0.55366, step = 301 (0.630 sec)
INFO:tensorflow:train_accuracy = 0.875 (0.057 sec)
INFO:tensorflow:train_accuracy = 0.879076 (0.059 sec)
INFO:tensorflow:train_accuracy = 0.880208 (0.065 sec)
INFO:tensorflow:train_accuracy = 0.8825 (0.067 sec)
INFO:tensorflow:train_accuracy = 0.887019 (0.068 sec)
INFO:tensorflow:train_accuracy = 0.888889 (0.067 sec)
INFO:tensorflow:train_accuracy = 0.891741 (0.073 sec)
INFO:tensorflow:train_accuracy = 0.894397 (0.069 sec)
INFO:tensorflow:train_accuracy = 0.891667 (0.043 sec)
INFO:tensorflow:global_step/sec: 163.315
INFO:tensorflow:train_accuracy = 0.893145 (0.043 sec)
INFO:tensorflow:loss = 0.539735, step = 401 (0.612 sec)
INFO:tensorflow:train_accuracy = 0.892578 (0.043 sec)
INFO:tensorflow:train_accuracy = 0.892992 (0.044 sec)
INFO:tensorflow:train_accuracy = 0.894301 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.895536 (0.041 sec)
INFO:tensorflow:train_accuracy = 0.895833 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.896959 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.899671 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.89984 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.9 (0.039 sec)
INFO:tensorflow:global_step/sec: 247.879
INFO:tensorflow:train_accuracy = 0.902439 (0.042 sec)
INFO:tensorflow:loss = 0.0332606, step = 501 (0.403 sec)
INFO:tensorflow:train_accuracy = 0.903274 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.903343 (0.035 sec)
INFO:tensorflow:train_accuracy = 0.903409 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.902083 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.902853 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.904255 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.905599 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.90625 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.9075 (0.039 sec)
INFO:tensorflow:global_step/sec: 265.028
INFO:tensorflow:train_accuracy = 0.90625 (0.039 sec)
INFO:tensorflow:loss = 0.376897, step = 601 (0.377 sec)
INFO:tensorflow:train_accuracy = 0.907452 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.908019 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.907986 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.909659 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.910156 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.911184 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.912177 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.912606 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.913542 (0.035 sec)
INFO:tensorflow:global_step/sec: 269.442
INFO:tensorflow:train_accuracy = 0.913934 (0.036 sec)
INFO:tensorflow:loss = 0.18219, step = 701 (0.371 sec)
INFO:tensorflow:train_accuracy = 0.91381 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.914683 (0.035 sec)
INFO:tensorflow:train_accuracy = 0.915527 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.916346 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.917614 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.91791 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.918199 (0.036 sec)
INFO:tensorflow:train_accuracy = 0.918478 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.919196 (0.039 sec)
INFO:tensorflow:global_step/sec: 270.664
INFO:tensorflow:train_accuracy = 0.919894 (0.041 sec)
INFO:tensorflow:loss = 0.0421399, step = 801 (0.371 sec)
INFO:tensorflow:train_accuracy = 0.920573 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.921233 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.921875 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.922083 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.923109 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.924107 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.92508 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.926028 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.925 (0.035 sec)
INFO:tensorflow:global_step/sec: 259.249
INFO:tensorflow:train_accuracy = 0.925926 (0.040 sec)
INFO:tensorflow:loss = 0.00482497, step = 901 (0.385 sec)
INFO:tensorflow:train_accuracy = 0.926448 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.927334 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.928199 (0.041 sec)
INFO:tensorflow:train_accuracy = 0.928309 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.928416 (0.040 sec)
INFO:tensorflow:train_accuracy = 0.928879 (0.037 sec)
INFO:tensorflow:train_accuracy = 0.929688 (0.038 sec)
INFO:tensorflow:train_accuracy = 0.930477 (0.039 sec)
INFO:tensorflow:train_accuracy = 0.93125 (0.041 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/mnist/model.ckpt.
INFO:tensorflow:Loss for final step: 0.126534.
INFO:tensorflow:Starting evaluation at 2018-04-11-16:16:32
INFO:tensorflow:Restoring parameters from /tmp/mnist/model.ckpt-1000
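
Once export_savedmodel has run, the exported model can be reloaded for inference. A minimal sketch using tf.contrib.predictor from TF 1.x; export_dir is a placeholder for the timestamped subdirectory that export_savedmodel creates under /tmp/mnist_model/:

from tensorflow.contrib import predictor

# export_dir: the timestamped subdirectory created by export_savedmodel (placeholder).
# Depending on the TF version, the signature key may need to be passed
# explicitly, e.g. signature_def_key='classify'.
predict_fn = predictor.from_saved_model(export_dir)
output = predict_fn({'image': X_tst[:1].astype(np.float32)})
print(output['classes'], output['probabilities'].shape)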
