Simple transfer learning with an Inception v3 architecture model.

This example shows how to take an Inception v3 model trained on ImageNet images and train a new top layer that can recognize other classes of images.

The top layer receives as input a 2048-dimensional vector for each image. We train a softmax layer on top of this representation. If the softmax layer has N labels, this means learning N + 2048*N model parameters: N biases and 2048*N weights.
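As a quick check of that count, here is a sketch for a hypothetical five-class problem (for example, five flower folders). Only the array shapes matter; the zero values are placeholders.

```python
import numpy as np

BOTTLENECK_SIZE = 2048  # Same value as BOTTLENECK_TENSOR_SIZE in the code below.
num_labels = 5          # Hypothetical, e.g. five flower folders.

# One weight per (bottleneck feature, label) pair, plus one bias per label.
weights = np.zeros((BOTTLENECK_SIZE, num_labels), dtype=np.float32)
biases = np.zeros(num_labels, dtype=np.float32)

param_count = weights.size + biases.size
print(param_count)  # 10245 = 2048*5 + 5
```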

Here's an example, which assumes you have a folder containing class-named subfolders, each containing the images for one label. The example folder flower_photos should have a structure like this:

~/flower_photos/daisy/photo1.jpg
~/flower_photos/daisy/photo2.jpg
...
~/flower_photos/rose/anotherphoto77.jpg
...
~/flower_photos/sunflower/somepicture.jpg

The subfolder names are important, since they define what label is applied to each image, but the filenames themselves don't matter. Once your images are prepared, you can run the training with a command like this:

bazel build third_party/tensorflow/examples/image_retraining:retrain && \
bazel-bin/third_party/tensorflow/examples/image_retraining/retrain \
--image_dir ~/flower_photos

You can replace the image_dir argument with any folder containing subfolders of images. The label for each image is taken from the name of the subfolder it's in.

This produces a new model file that can be loaded and run by any TensorFlow program, for example the label_image sample code.


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import glob
import hashlib
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

from tensorflow.python.client import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile

In [2]:
# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and their
# sizes. If you want to adapt this script to work with another model, you will
# need to update these to reflect the values in the network you're using.
# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape'
BOTTLENECK_TENSOR_SIZE = 2048
MODEL_INPUT_WIDTH = 299
MODEL_INPUT_HEIGHT = 299
MODEL_INPUT_DEPTH = 3
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents'
RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear'

Introduction

Since tf.app.flags.FLAGS is meant to be populated from command-line arguments, it does not work well inside an IPython Notebook. Instead, we create our own class holding all the parameters and swap an instance of it in for the FLAGS object of the retrain module.
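The swap works as long as the functions in retrain.py read FLAGS as a module-level global at call time (as the local copy of maybe_download_and_extract below does with FLAGS.model_dir), rather than capturing it at import. The sketch below demonstrates that Python behavior with a throwaway module built at runtime; fake_retrain, describe, and Params are hypothetical stand-ins, not part of retrain.py.

```python
import types

# Build a throwaway module (a hypothetical stand-in for retrain.py) whose
# function reads the module-level global FLAGS each time it is called.
fake_retrain = types.ModuleType('fake_retrain')
exec(
    "FLAGS = None\n"
    "def describe():\n"
    "    return 'steps=%d, lr=%g' % (FLAGS.how_many_training_steps,\n"
    "                                FLAGS.learning_rate)\n",
    fake_retrain.__dict__,
)


class Params(object):
    """Plain class whose attributes mimic parsed command-line flags."""
    how_many_training_steps = 4000
    learning_rate = 0.01


# Replace the module's FLAGS; describe() sees the new object because it looks
# the name up in its module's globals on every call.
fake_retrain.FLAGS = Params()
summary = fake_retrain.describe()
print(summary)  # steps=4000, lr=0.01
```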


In [3]:
import retrain as nnrt


class retrain_parameters(object):
    """Stand-in for the tf.app.flags.FLAGS values that retrain.py expects."""
    image_dir = './'  # Path to folders of labeled images.
    model_dir = './'  # Path to classify_image_graph_def.pb,
                      # imagenet_synset_to_human_label_map.txt, and
                      # imagenet_2012_challenge_label_map_proto.pbtxt.
    output_graph = '/tmp/output_graph.pb'  # Where to save the trained graph.
    output_labels = '/tmp/output_labels.txt'  # Where to save the trained graph's labels.

    # Details of the training configuration.
    how_many_training_steps = 4000  # How many training steps to run before ending.
    learning_rate = 0.01  # How large a learning rate to use when training.

    testing_percentage = 10  # What percentage of images to use as a test set.
    validation_percentage = 10  # What percentage of images to use as a validation set.

    eval_step_interval = 10  # How often to evaluate the training results.
    train_batch_size = 100  # How many images to train on at a time.
    test_batch_size = 500  # How many images to test on at a time. This test set
                           # is only used infrequently to verify the overall
                           # accuracy of the model.
    validation_batch_size = 100  # How many images to use in an evaluation batch.
                                 # This validation set is used much more often
                                 # than the test set, and is an early indicator
                                 # of how accurate the model is during training.

    bottleneck_dir = '/tmp/bottleneck'  # Path to cache bottleneck layer values as files.
    final_tensor_name = 'final_result'  # Name of the output classification layer in the retrained graph.

    # Controls the distortions used during training.
    flip_left_right = False  # Whether to randomly flip half of the training images horizontally.
    random_crop = 0  # Percentage margin to randomly crop off the training images.
    random_scale = 0  # Percentage by which to randomly scale up the training images.
    random_brightness = 0  # Percentage by which to randomly multiply the training image pixels up or down.


FLAGS = retrain_parameters()
# Replace the module-level FLAGS that the retrain module's functions read.
nnrt.FLAGS = FLAGS

In [ ]:
def maybe_download_and_extract():
    """Download and extract the model tar file.

    If the pretrained model we're using doesn't already exist, this function
    downloads it from the TensorFlow.org website and unpacks it into a directory.
    """
    dest_directory = FLAGS.model_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' %
                             (filename,
                              float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL,
                                                 filepath,
                                                 reporthook=_progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def create_inception_graph():
    """Creates a graph from the saved GraphDef file and returns a Graph object.

    Returns:
      Graph holding the trained Inception network.
    """
    with tf.Session() as sess:
        # The GraphDef is a binary protobuf, so open it in binary mode.
        with gfile.FastGFile(
                os.path.join(FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name='')
    return sess.graph

In [4]:
# Set up the pre-trained graph.

nnrt.maybe_download_and_extract()
graph = nnrt.create_inception_graph()

In [ ]:
# Look at the folder structure, and create lists of all the images.
image_lists = nnrt.create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                      FLAGS.validation_percentage)
class_count = len(image_lists.keys())
# Stop here so the later cells don't run without enough classes.
if class_count == 0:
    raise ValueError('No valid folders of images found at ' + FLAGS.image_dir)
if class_count == 1:
    raise ValueError('Only one valid folder of images found at ' + FLAGS.image_dir +
                     ' - multiple classes are needed for classification.')
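The source of nnrt.create_image_lists isn't shown in this notebook. As a rough illustration of the idea (one entry per class folder, with its images split into training, testing, and validation lists), here is a sketch under our own assumptions: the name split_image_lists, the dictionary keys, and the SHA-1-of-filename split are illustrative choices and may not match retrain.py. Hashing the filename means an image lands in the same set on every run, even if new images are added later.

```python
import hashlib
import os
import tempfile


def split_image_lists(image_dir, testing_percentage, validation_percentage):
    """Hypothetical sketch: map each class folder to train/test/validation lists.

    Each file is assigned by hashing its name, so the split is the same on
    every run and does not depend on a random seed.
    """
    result = {}
    for label in sorted(os.listdir(image_dir)):
        sub_dir = os.path.join(image_dir, label)
        if not os.path.isdir(sub_dir):
            continue
        images = [name for name in sorted(os.listdir(sub_dir))
                  if name.lower().endswith(('.jpg', '.jpeg'))]
        if not images:
            continue  # Skip folders with no usable images.
        lists = {'dir': label, 'training': [], 'testing': [], 'validation': []}
        for name in images:
            # Map the SHA-1 digest of the filename to a value in [0, 100).
            digest = hashlib.sha1(name.encode('utf-8')).hexdigest()
            percentage = (int(digest, 16) % 10000) / 100.0
            if percentage < validation_percentage:
                lists['validation'].append(name)
            elif percentage < validation_percentage + testing_percentage:
                lists['testing'].append(name)
            else:
                lists['training'].append(name)
        result[label] = lists
    return result


# Build a tiny throwaway folder tree shaped like flower_photos.
root = tempfile.mkdtemp()
for label in ('daisy', 'rose'):
    os.makedirs(os.path.join(root, label))
    for i in range(20):
        open(os.path.join(root, label, 'photo%d.jpg' % i), 'w').close()

demo_lists = split_image_lists(root, testing_percentage=10, validation_percentage=10)
print(sorted(demo_lists.keys()))  # ['daisy', 'rose']
```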

In [ ]:
# See if the FLAGS settings mean we're applying any distortions.
do_distort_images = nnrt.should_distort_images(
  FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
  FLAGS.random_brightness)
ground_truth_tensor_name = 'ground_truth'
distorted_image_name = 'distorted_image'
distorted_jpeg_data_tensor_name = 'distorted_jpeg_data'
sess = tf.Session()

if do_distort_images:
    # We will be applying distortions, so set up the operations we'll need.
    nnrt.add_input_distortions(FLAGS.flip_left_right, FLAGS.random_crop,
                          FLAGS.random_scale, FLAGS.random_brightness,
                          distorted_jpeg_data_tensor_name, distorted_image_name)
else:
    # We'll make sure we've calculated the 'bottleneck' image summaries and
    # cached them on disk.
    nnrt.cache_bottlenecks(sess, image_lists, FLAGS.image_dir, FLAGS.bottleneck_dir)
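When no distortions are requested, each image goes through the unchanged Inception layers identically on every step, so its bottleneck vector can be computed once and reused; with distortions, every pass produces a different input, which is why that branch skips the cache. The sketch below shows the caching idea with a hypothetical get_or_create_bottleneck helper and a stand-in compute_fn. The comma-separated text format and the per-label file layout are assumptions and may differ from what nnrt.cache_bottlenecks actually writes.

```python
import os
import tempfile

import numpy as np


def get_or_create_bottleneck(cache_dir, label, image_name, compute_fn):
    """Hypothetical sketch: return a cached bottleneck, computing it only once.

    compute_fn stands in for running the image through Inception up to the
    bottleneck layer; any function returning a 1-D array works here.
    """
    label_dir = os.path.join(cache_dir, label)
    if not os.path.exists(label_dir):
        os.makedirs(label_dir)
    path = os.path.join(label_dir, image_name + '.txt')
    if not os.path.exists(path):
        values = np.asarray(compute_fn(image_name), dtype=np.float32)
        with open(path, 'w') as f:
            # Store the vector as comma-separated text (an assumed format).
            f.write(','.join(repr(float(v)) for v in values))
    with open(path) as f:
        return np.array([float(v) for v in f.read().split(',')], dtype=np.float32)


calls = []


def fake_compute(image_name):
    # Record each call so we can see when the cache is bypassed.
    calls.append(image_name)
    return np.full(2048, 0.5)


cache_dir = tempfile.mkdtemp()
first = get_or_create_bottleneck(cache_dir, 'daisy', 'photo1.jpg', fake_compute)
second = get_or_create_bottleneck(cache_dir, 'daisy', 'photo1.jpg', fake_compute)
print(len(calls))  # 1: the second call read the vector back from disk
```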

In [ ]:
# Add the new layer that we'll be training.
train_step, cross_entropy = nnrt.add_final_training_ops(
  graph, len(image_lists.keys()), FLAGS.final_tensor_name,
  ground_truth_tensor_name)

# Set up all our weights to their initial default values.
init = tf.initialize_all_variables()
sess.run(init)
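nnrt.add_final_training_ops adds the new softmax layer described at the top of this notebook. To show what one training step does mathematically, here is a NumPy sketch of a single plain gradient-descent update on the mean softmax cross-entropy. It is not the TensorFlow code from retrain.py; the function names are hypothetical, and the usage runs on random synthetic features, so it only shows that with zero initial weights the loss starts at log(N) and then decreases. The smaller learning rate in the usage is chosen for these synthetic features.

```python
import numpy as np


def softmax(logits):
    # Subtract each row's max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def numpy_train_step(weights, biases, bottlenecks, ground_truth, learning_rate):
    """One gradient-descent step on the mean softmax cross-entropy.

    bottlenecks: (batch, features) array; ground_truth: (batch, N) one-hot rows.
    Returns the updated weights and biases plus the loss before the update.
    """
    batch = bottlenecks.shape[0]
    probs = softmax(bottlenecks.dot(weights) + biases)
    loss = -np.sum(ground_truth * np.log(probs + 1e-12)) / batch
    # For softmax + cross-entropy, d(loss)/d(logits) = (probs - labels) / batch.
    grad_logits = (probs - ground_truth) / batch
    new_weights = weights - learning_rate * bottlenecks.T.dot(grad_logits)
    new_biases = biases - learning_rate * grad_logits.sum(axis=0)
    return new_weights, new_biases, loss


# Synthetic stand-in data: 100 random 2048-d "bottlenecks" in 3 classes.
rng = np.random.RandomState(0)
num_classes = 3
x = rng.rand(100, 2048).astype(np.float32)
y = np.eye(num_classes, dtype=np.float32)[rng.randint(num_classes, size=100)]
w = np.zeros((2048, num_classes), dtype=np.float32)
b = np.zeros(num_classes, dtype=np.float32)

losses = []
for _ in range(50):
    w, b, loss = numpy_train_step(w, b, x, y, learning_rate=0.001)
    losses.append(loss)
print(losses[0], losses[-1])  # starts at log(3), about 1.0986, then decreases
```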

In [ ]:
# Create the operations we need to evaluate the accuracy of our new layer.
evaluation_step = nnrt.add_evaluation_step(graph, FLAGS.final_tensor_name,
                                    ground_truth_tensor_name)

# Get some layers we'll need to access during training.
bottleneck_tensor = graph.get_tensor_by_name(nnrt.ensure_name_has_port(
  BOTTLENECK_TENSOR_NAME))
ground_truth_tensor = graph.get_tensor_by_name(nnrt.ensure_name_has_port(
  ground_truth_tensor_name))
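Two small ideas are at work in this cell. A TensorFlow tensor name has the form op_name:output_index, and get_tensor_by_name expects that full form, which is presumably why the names go through nnrt.ensure_name_has_port first. The evaluation step then compares the highest-scoring predicted class with the labeled class. The sketch below reimplements both ideas in plain Python and NumPy with hypothetical names (ensure_port, accuracy); it illustrates the logic and is not the retrain module's code.

```python
import numpy as np


def ensure_port(tensor_name):
    """Hypothetical helper: append ':0' when a name lacks an output index.

    'pool_3/_reshape' names an operation; 'pool_3/_reshape:0' names its
    first output tensor.
    """
    return tensor_name if ':' in tensor_name else tensor_name + ':0'


def accuracy(predictions, ground_truth):
    """Fraction of rows where the top-scoring class matches the one-hot label."""
    correct = np.argmax(predictions, axis=1) == np.argmax(ground_truth, axis=1)
    return float(np.mean(correct))


preds = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.5, 0.4, 0.1]])
truth = np.array([[1, 0, 0],
                  [0, 0, 1],
                  [0, 1, 0]])
print(ensure_port('pool_3/_reshape'))  # pool_3/_reshape:0
print(accuracy(preds, truth))          # 0.6666666666666666
```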

In [ ]:
# Run the training for as many steps as FLAGS.how_many_training_steps requests.
for i in range(FLAGS.how_many_training_steps):
    # Get a batch of input bottleneck values, either calculated fresh every time
    # with distortions applied, or from the cache stored on disk.
    if do_distort_images:
        train_bottlenecks, train_ground_truth = nnrt.get_random_distorted_bottlenecks(
            sess, graph, image_lists, FLAGS.train_batch_size, 'training',
            FLAGS.image_dir, distorted_jpeg_data_tensor_name,
            distorted_image_name)
    else:
        train_bottlenecks, train_ground_truth = nnrt.get_random_cached_bottlenecks(
            sess, image_lists, FLAGS.train_batch_size, 'training',
            FLAGS.bottleneck_dir, FLAGS.image_dir)
    # Feed the bottlenecks and ground truth into the graph, and run a training
    # step.
    sess.run(train_step,
             feed_dict={bottleneck_tensor: train_bottlenecks,
                        ground_truth_tensor: train_ground_truth})
    # Every so often, print out how well the graph is training.
    is_last_step = (i + 1 == FLAGS.how_many_training_steps)
    if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
        train_accuracy, cross_entropy_value = sess.run(
            [evaluation_step, cross_entropy],
            feed_dict={bottleneck_tensor: train_bottlenecks,
                       ground_truth_tensor: train_ground_truth})
        print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
                                                        train_accuracy * 100))
        print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
                                                   cross_entropy_value))
        validation_bottlenecks, validation_ground_truth = (
            nnrt.get_random_cached_bottlenecks(
                sess, image_lists, FLAGS.validation_batch_size, 'validation',
                FLAGS.bottleneck_dir, FLAGS.image_dir))
        validation_accuracy = sess.run(
            evaluation_step,
            feed_dict={bottleneck_tensor: validation_bottlenecks,
                       ground_truth_tensor: validation_ground_truth})
        print('%s: Step %d: Validation accuracy = %.1f%%' %
              (datetime.now(), i, validation_accuracy * 100))
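Each training step draws a fresh random batch of bottlenecks together with one-hot ground-truth rows, and every FLAGS.eval_step_interval steps (plus the final step) it also reports training and validation accuracy. The following sketch shows one way to assemble such a batch from an in-memory cache. random_batch and the cached dictionary layout are hypothetical, and picking a class first and then an image within it (so classes are sampled evenly) is a design choice of this sketch that may differ from nnrt.get_random_cached_bottlenecks.

```python
import random

import numpy as np


def random_batch(cached, batch_size, rng):
    """Hypothetical sketch: sample a batch of cached bottlenecks with replacement.

    cached maps label -> list of 1-D bottleneck arrays. Returns an array of
    shape (batch_size, feature_dim) and matching one-hot ground-truth rows.
    """
    labels = sorted(cached)
    bottlenecks, ground_truth = [], []
    for _ in range(batch_size):
        label_index = rng.randrange(len(labels))
        vectors = cached[labels[label_index]]
        bottlenecks.append(vectors[rng.randrange(len(vectors))])
        one_hot = np.zeros(len(labels), dtype=np.float32)
        one_hot[label_index] = 1.0
        ground_truth.append(one_hot)
    return np.stack(bottlenecks), np.stack(ground_truth)


cached = {'daisy': [np.zeros(4), np.ones(4)], 'rose': [np.full(4, 2.0)]}
batch_x, batch_y = random_batch(cached, 8, random.Random(0))
print(batch_x.shape, batch_y.shape)  # (8, 4) (8, 2)
```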

In [ ]:
# We've completed all our training, so run a final test evaluation on
# some new images we haven't used before.
test_bottlenecks, test_ground_truth = nnrt.get_random_cached_bottlenecks(
  sess, image_lists, FLAGS.test_batch_size, 'testing',
  FLAGS.bottleneck_dir, FLAGS.image_dir)
test_accuracy = sess.run(
  evaluation_step,
  feed_dict={bottleneck_tensor: test_bottlenecks,
             ground_truth_tensor: test_ground_truth})
print('Final test accuracy = %.1f%%' % (test_accuracy * 100))

# Write out the trained graph and labels with the weights stored as constants.
output_graph_def = graph_util.convert_variables_to_constants(
  sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
    f.write('\n'.join(image_lists.keys()) + '\n')
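The labels file is written one label per line in image_lists.keys() order. Assuming that order matches the column order of the final_result softmax output, a program that runs the exported graph can turn its probability vector back into label names. The helper top_labels below is a hypothetical sketch of only that last step; it does not load or run the graph itself.

```python
import os
import tempfile

import numpy as np


def top_labels(probabilities, labels_path, k=3):
    """Return the k most likely (label, probability) pairs, best first.

    Assumes labels_path holds one label per line, in the same order as the
    columns of the retrained softmax output.
    """
    with open(labels_path) as f:
        labels = [line.strip() for line in f if line.strip()]
    order = np.argsort(probabilities)[::-1][:k]  # Indices of the k largest values.
    return [(labels[i], float(probabilities[i])) for i in order]


labels_path = os.path.join(tempfile.mkdtemp(), 'output_labels.txt')
with open(labels_path, 'w') as f:
    f.write('daisy\nrose\nsunflower\n')
probs = np.array([0.1, 0.7, 0.2])
print(top_labels(probs, labels_path, k=2))  # [('rose', 0.7), ('sunflower', 0.2)]
```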