TF-TRT Inference from Saved Model with TensorFlow 2

In this notebook, we demonstrate how to create a TF-TRT optimized model from a TensorFlow saved model.

This notebook was designed to run with TensorFlow 2.x, which is included in the NVIDIA NGC TensorFlow containers starting with nvcr.io/nvidia/tensorflow:19.12-tf2-py3. The containers can be downloaded from the NGC website.

Notebook Content

  1. Pre-requisite: data and model
  2. Verifying the original FP32 model
  3. Creating TF-TRT FP32 model
  4. Creating TF-TRT FP16 model
  5. Creating TF-TRT INT8 model
  6. Calibrating TF-TRT INT8 model with raw JPEG images

Quick start

We will run this demonstration with a saved ResNet-50 v1 model, which is downloaded and stored at /path/to/saved_model.

The INT8 calibration process requires access to a small but representative sample of real training or validation data.

We will use the ImageNet dataset stored in TFRecord format. Google provides an all-in-one script for downloading and preparing the ImageNet dataset at

https://github.com/tensorflow/models/blob/master/research/inception/inception/data/download_and_preprocess_imagenet.sh.

To run this notebook, start the NGC TF container, providing the correct paths to the ImageNet validation data (/path/to/image_net) and to the folder containing the TF saved model (/path/to/saved_model):

nvidia-docker run --rm -it -p 8888:8888 -v /path/to/image_net:/data  -v /path/to/saved_model:/saved_model --name TFTRT nvcr.io/nvidia/tensorflow:19.12-tf2-py3

Within the container, we then start Jupyter notebook with:

jupyter notebook --ip 0.0.0.0 --port 8888  --allow-root

Connect to the Jupyter notebook web interface on your host at http://localhost:8888.

1. Pre-requisite: data and model

We first install some extra packages and external dependencies needed for tasks such as preprocessing the ImageNet data.


In [ ]:
%%bash
pushd /workspace/nvidia-examples/tensorrt/tftrt/examples/object_detection/ 
bash install_dependencies.sh;
popd

In [ ]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'

import time
import logging
import numpy as np

import tensorflow as tf
print("TensorFlow version: ", tf.__version__)

from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants

logging.getLogger("tensorflow").setLevel(logging.ERROR)

# check TensorRT version
print("TensorRT version: ")
!dpkg -l | grep nvinfer

Data

We verify that the correct ImageNet data folder has been mounted and that validation data files of the form validation-00xxx-of-00128 are available.


In [ ]:
def get_files(data_dir, filename_pattern):
    """Return all files in data_dir that match filename_pattern."""
    if data_dir is None:
        return []
    files = tf.io.gfile.glob(os.path.join(data_dir, filename_pattern))
    if not files:
        raise ValueError('Cannot find any files in {} with '
                         'pattern "{}"'.format(data_dir, filename_pattern))
    return files

In [ ]:
VALIDATION_DATA_DIR = "/data"
validation_files = get_files(VALIDATION_DATA_DIR, 'validation*')
print('There are %d validation files. \n%s\n%s\n...'%(len(validation_files), validation_files[0], validation_files[-1]))
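
As an optional sanity check, the shard names can be compared against the expected naming scheme. The sketch below is a minimal example that assumes the validation-00xxx-of-00128 pattern mentioned above; the helper name find_missing_shards is hypothetical and is not part of the original notebook.

```python
import os
import re

# Assumed shard naming scheme: validation-<5-digit index>-of-<5-digit total>
SHARD_PATTERN = re.compile(r"^validation-(\d{5})-of-(\d{5})$")

def find_missing_shards(paths):
    """Return the sorted shard indices that are absent from `paths`."""
    found = set()
    total = None
    for path in paths:
        match = SHARD_PATTERN.match(os.path.basename(path))
        if match is None:
            continue  # ignore files that do not follow the shard naming scheme
        found.add(int(match.group(1)))
        total = int(match.group(2))
    if total is None:
        raise ValueError("No validation shards found")
    return sorted(set(range(total)) - found)

# Synthetic file list in which shard 3 is missing
example = ["/data/validation-%05d-of-00128" % i for i in range(128) if i != 3]
print(find_missing_shards(example))  # prints [3]
```

Under that naming assumption, calling find_missing_shards(validation_files) in the notebook should return an empty list when all 128 shards are mounted.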

TF saved model

If it is not already present, we download a pretrained ResNet-50 v1 saved model published with the TensorFlow official models at https://github.com/tensorflow/models/tree/master/official/resnet


In [ ]:
%%bash
# Download the ResNet-50 v1 FP32 saved model archive unless it is already present
FILE=/saved_model/resnet_v1_fp32_savedmodel_NHWC.tar.gz
if [ -f "$FILE" ]; then
   echo "The file '$FILE' exists."
else
   echo "The file '$FILE' is not found. Downloading..."
   wget -P /saved_model/ http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp32_savedmodel_NHWC.tar.gz
fi

tar -xzvf "$FILE" -C /saved_model

Helper functions

We define a few helper functions to read and preprocess ImageNet data from TFRecord files.


In [ ]:
def deserialize_image_record(record):
    feature_map = {
        'image/encoded':          tf.io.FixedLenFeature([ ], tf.string, ''),
        'image/class/label':      tf.io.FixedLenFeature([1], tf.int64,  -1),
        'image/class/text':       tf.io.FixedLenFeature([ ], tf.string, ''),
        'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32)
    }
    with tf.name_scope('deserialize_image_record'):
        obj = tf.io.parse_single_example(record, feature_map)
        imgdata = obj['image/encoded']
        label   = tf.cast(obj['image/class/label'], tf.int32)
        bbox    = tf.stack([obj['image/object/bbox/%s'%x].values
                            for x in ['ymin', 'xmin', 'ymax', 'xmax']])
        bbox = tf.transpose(tf.expand_dims(bbox, 0), [0,2,1])
        text    = obj['image/class/text']
        return imgdata, label, bbox, text

In [ ]:
from preprocessing import vgg_preprocess as vgg_preprocessing

def preprocess(record):
    # Parse TFRecord
    imgdata, label, bbox, text = deserialize_image_record(record)
    # label -= 1  # Change to 0-based if not using background class
    try:
        image = tf.image.decode_jpeg(imgdata, channels=3, fancy_upscaling=False, dct_method='INTEGER_FAST')
    except Exception:
        image = tf.image.decode_png(imgdata, channels=3)

    image = vgg_preprocessing(image, 224, 224)
    return image, label

In [ ]:
#Define some global variables
BATCH_SIZE = 64

2. Verifying the original FP32 model

We demonstrate the conversion process with a ResNet-50 v1 model. First, we inspect the original TensorFlow model.


In [ ]:
SAVED_MODEL_DIR =  "/saved_model/resnet_v1_fp32_savedmodel_NHWC/1538686669/"

We employ saved_model_cli to inspect the inputs and outputs of the model.


In [ ]:
!saved_model_cli show --all --dir $SAVED_MODEL_DIR

This gives us the names of the input and output tensors: input_tensor:0 and softmax_tensor:0, respectively. Also note that the number of output classes is 1001 rather than the 1000 ImageNet classes, because the network was trained with an extra background class.
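
The extra background class shifts every class index by one. Below is a minimal sketch of the index mapping, assuming the background class occupies index 0, which is consistent with the 1-based labels stored in the TFRecord files. The function names are hypothetical.

```python
NUM_IMAGENET_CLASSES = 1000
BACKGROUND_INDEX = 0  # assumption: the background class is stored first

def to_imagenet_index(model_index):
    """Map a 1001-way model output index to a 0-based ImageNet class index."""
    if model_index == BACKGROUND_INDEX:
        return None  # the background class has no ImageNet counterpart
    if not 1 <= model_index <= NUM_IMAGENET_CLASSES:
        raise ValueError("model index out of range: %d" % model_index)
    return model_index - 1

def to_model_index(imagenet_index):
    """Map a 0-based ImageNet class index to the 1001-way model index."""
    if not 0 <= imagenet_index < NUM_IMAGENET_CLASSES:
        raise ValueError("ImageNet index out of range: %d" % imagenet_index)
    return imagenet_index + 1

print(to_imagenet_index(1), to_imagenet_index(1000), to_imagenet_index(0))
```

Because the labels in the TFRecord files and the model output share the same 1001-way indexing, the notebook can compare them directly without applying this shift.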


In [ ]:
INPUT_TENSOR = 'input_tensor:0'
OUTPUT_TENSOR = 'softmax_tensor:0'

Next, we define a function that loads a saved model and measures its speed and accuracy on the validation data.


In [ ]:
def benchmark_saved_model(SAVED_MODEL_DIR, BATCH_SIZE=64):
    # load saved model
    saved_model_loaded = tf.saved_model.load(SAVED_MODEL_DIR, tags=[tag_constants.SERVING])
    signature_keys = list(saved_model_loaded.signatures.keys())
    print(signature_keys)

    infer = saved_model_loaded.signatures['serving_default']
    print(infer.structured_outputs)

    # prepare dataset iterator
    dataset = tf.data.TFRecordDataset(validation_files)
    dataset = dataset.map(map_func=preprocess, num_parallel_calls=20)
    dataset = dataset.batch(batch_size=BATCH_SIZE, drop_remainder=True)

    # warm-up runs are excluded from the timing below
    print('Warming up for 50 batches...')
    cnt = 0
    for x, y in dataset:
        labeling = infer(x)
        cnt += 1
        if cnt == 50:
            break

    print('Benchmarking inference engine...')
    num_hits = 0
    num_predict = 0
    start_time = time.time()
    for x, y in dataset:
        labeling = infer(x)
        preds = labeling['classes'].numpy().reshape(-1)
        # Labels arrive with shape (batch, 1); flatten them so the comparison
        # is element-wise instead of being broadcast to (batch, batch)
        labels = y.numpy().reshape(-1)
        num_hits += np.sum(preds == labels)
        num_predict += preds.shape[0]
    elapsed_time = time.time() - start_time

    print('Accuracy: %.2f%%'%(100*num_hits/num_predict))
    print('Inference speed: %.2f samples/s'%(num_predict/elapsed_time))

In [ ]:
benchmark_saved_model(SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)
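
The measurement pattern used by benchmark_saved_model (warm up first, then time a full pass and divide the number of samples by the elapsed time) can be shown in isolation. The sketch below replaces the model with a hypothetical stand-in function, fake_infer, so the speed it reports says nothing about TF-TRT; it only illustrates the bookkeeping.

```python
import time

def measure_throughput(infer_fn, batches, num_warmup=2):
    """Return (accuracy, samples_per_second) for `infer_fn` over `batches`.

    `batches` is a list of (inputs, labels) pairs and `infer_fn` maps a list
    of inputs to a list of predicted labels.
    """
    # Warm-up runs are executed before the timer starts
    for inputs, _ in batches[:num_warmup]:
        infer_fn(inputs)

    num_hits = 0
    num_predict = 0
    start = time.perf_counter()
    for inputs, labels in batches:
        preds = infer_fn(inputs)
        num_hits += sum(int(p == y) for p, y in zip(preds, labels))
        num_predict += len(preds)
    # Guard against a zero interval on very coarse timers
    elapsed = max(time.perf_counter() - start, 1e-9)
    return num_hits / num_predict, num_predict / elapsed

def fake_infer(inputs):
    # Hypothetical stand-in for the model: predicts the parity of each input
    return [x % 2 for x in inputs]

batches = [([1, 2, 3, 4], [1, 0, 1, 1]), ([5, 6, 7, 8], [1, 0, 1, 0])]
accuracy, speed = measure_throughput(fake_infer, batches)
print("Accuracy: %.2f%%" % (100 * accuracy))
print("Inference speed: %.2f samples/s" % speed)
```

Keeping the warm-up outside the timed region matters most for the TF-TRT models, where the first calls may include one-time setup work that would otherwise distort the throughput figure.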

3. Creating TF-TRT FP32 model

Next, we convert the native TF FP32 model to TF-TRT FP32, then verify model accuracy and inference speed.


In [ ]:
FP32_SAVED_MODEL_DIR = SAVED_MODEL_DIR+"_TFTRT_FP32/1"
!rm -rf $FP32_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.FP32)

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert()

converter.save(FP32_SAVED_MODEL_DIR)

In [ ]:
benchmark_saved_model(FP32_SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)

4. Creating TF-TRT FP16 model

Next, we convert the native TF FP32 model to TF-TRT FP16, then verify model accuracy and inference speed.


In [ ]:
FP16_SAVED_MODEL_DIR = SAVED_MODEL_DIR+"_TFTRT_FP16/1"
!rm -rf $FP16_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.FP16)

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert()

converter.save(FP16_SAVED_MODEL_DIR)

In [ ]:
benchmark_saved_model(FP16_SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)
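
Reduced numerical precision is the reason accuracy is re-verified after conversion. As a rough illustration that is independent of TensorRT, the sketch below rounds a few values to IEEE 754 half precision using Python's struct module. The helper to_half is hypothetical.

```python
import struct

def to_half(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # The 'e' format code packs a float into 16 bits
    return struct.unpack("<e", struct.pack("<e", value))[0]

for value in (1.0, 0.1, 1e-3, 2049.0):
    half = to_half(value)
    print("%r -> %r (abs error %.3g)" % (value, half, abs(half - value)))
```

How much of this rounding error reaches the final predictions depends on the network and on which layers actually run in FP16, which is why the benchmark above is the more meaningful check.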

5. Creating TF-TRT INT8 model

Creating a TF-TRT INT8 inference model requires two steps:

  • Step 1: Prepare a calibration dataset

  • Step 2: Convert and calibrate the TF-TRT INT8 inference engine

Step 1: Prepare a calibration dataset

Creating a TF-TRT INT8 model requires a small calibration dataset. Ideally, this dataset should be representative of the data seen in production. It is used to build a histogram of values for each layer in the neural network so that the layer can be quantized to 8 bits effectively.
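
To make the role of the calibration data concrete, the sketch below shows symmetric 8-bit quantization with a scale derived from calibration values. It is deliberately simplified: the scale is taken from the maximum absolute value, which is an assumption made for illustration and not the range-selection method TensorRT applies to its histograms. All names are hypothetical.

```python
def calibrate_scale(calibration_values):
    """Derive a per-tensor scale so that the largest magnitude maps to 127."""
    max_abs = max(abs(v) for v in calibration_values)
    if max_abs == 0:
        raise ValueError("calibration data contains only zeros")
    return max_abs / 127.0

def quantize(value, scale):
    """Map a float to a signed 8-bit integer, clamping out-of-range values."""
    q = int(round(value / scale))
    return max(-127, min(127, q))

def dequantize(q, scale):
    """Map a signed 8-bit integer back to an approximate float."""
    return q * scale

# Values observed while running the calibration data through one layer
calibration_values = [-2.54, -0.5, 0.0, 0.3, 1.27]
scale = calibrate_scale(calibration_values)

for value in (0.3, 1.0, 5.0):
    q = quantize(value, scale)
    print("%.2f -> %4d -> %.4f" % (value, q, dequantize(q, scale)))
```

The value 5.0 lies outside the calibrated range and is clamped, which illustrates why the calibration set should resemble the data seen in production.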


In [ ]:
num_calibration_batches = 2

# prepare calibration dataset
dataset = tf.data.TFRecordDataset(validation_files)   
dataset = dataset.map(map_func=preprocess, num_parallel_calls=20)
dataset = dataset.batch(batch_size=BATCH_SIZE, drop_remainder=True) 
calibration_dataset = dataset.take(num_calibration_batches)

def calibration_input_fn():
    for x, y in calibration_dataset:
        yield (x, )
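
The batching and generator pattern above can be illustrated without TensorFlow. The following sketch mimics batch(..., drop_remainder=True) followed by take(...) on a plain Python list and yields one-element tuples, with one tuple entry per model input. The helper names are hypothetical.

```python
from itertools import islice

def batched(samples, batch_size):
    """Yield full batches only, like batch(..., drop_remainder=True)."""
    for start in range(0, len(samples) - batch_size + 1, batch_size):
        yield samples[start:start + batch_size]

def make_calibration_input_fn(samples, batch_size, num_batches):
    """Build a generator function that yields `num_batches` one-element tuples."""
    def calibration_input_fn():
        for batch in islice(batched(samples, batch_size), num_batches):
            yield (batch,)  # one tuple entry per model input
    return calibration_input_fn

input_fn = make_calibration_input_fn(list(range(10)), batch_size=4, num_batches=2)
print(list(input_fn()))  # prints [([0, 1, 2, 3],), ([4, 5, 6, 7],)]
```

Returning a function rather than a generator object means the data can be iterated from the start each time the function is called.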

Step 2: Convert and calibrate the TF-TRT INT8 inference engine

The calibration step may take a while to complete.


In [ ]:
# set a directory to write the saved model
INT8_SAVED_MODEL_DIR =  SAVED_MODEL_DIR + "_TFTRT_INT8/1"
!rm -rf $INT8_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.INT8)

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert(calibration_input_fn=calibration_input_fn)

converter.save(INT8_SAVED_MODEL_DIR)

Benchmarking INT8 saved model

Finally, we reload the INT8 saved model from disk and verify its accuracy and performance.


In [ ]:
benchmark_saved_model(INT8_SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)

In [ ]:
!saved_model_cli show --all --dir $INT8_SAVED_MODEL_DIR

6. Calibrating TF-TRT INT8 model with raw JPEG images

As an alternative to reading data in TFRecord format, this section demonstrates how to calibrate a TF-TRT INT8 model from a directory of raw JPEG images. We assume that the raw images have been mounted to the directory /data/Calibration_data.

As a rule of thumb, calibration data should be a small but representative set of images similar to what is expected in deployment. Empirically, for common network architectures trained on ImageNet data, a calibration set of 500-1000 images provides good accuracy. A good strategy for a dataset such as ImageNet is therefore to choose one sample from each class.
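
The one-sample-per-class strategy can be sketched as follows. The code assumes a hypothetical layout in which each image sits in a directory named after its class (for example /data/Calibration_data/n01440764/img_1.JPEG). This layout and the function name are assumptions made for illustration, not something the notebook prescribes.

```python
import os

def one_sample_per_class(filepaths):
    """Keep the first file (in sorted order) found for each class directory."""
    selected = {}
    for path in sorted(filepaths):
        # Assumption: the parent directory name identifies the class
        class_name = os.path.basename(os.path.dirname(path))
        selected.setdefault(class_name, path)
    return [selected[name] for name in sorted(selected)]

# Illustrative file list with two classes
files = [
    "/data/Calibration_data/n01440764/img_2.JPEG",
    "/data/Calibration_data/n01440764/img_1.JPEG",
    "/data/Calibration_data/n01443537/img_7.JPEG",
]
print(one_sample_per_class(files))
```

With a flat directory of images, as the code below assumes, the class would have to be recovered some other way, for example from a separate label file.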


In [ ]:
data_directory = "/data/Calibration_data"
calibration_files = [os.path.join(path, name) for path, _, files in os.walk(data_directory) for name in files]
print('There are %d calibration files. \n%s\n%s\n...'%(len(calibration_files), calibration_files[0], calibration_files[-1]))

We define a helper function to read and preprocess an image from a JPEG file.


In [ ]:
def parse_file(filepath):
    image = tf.io.read_file(filepath)
    image = tf.image.decode_jpeg(image, channels=3)
    image = vgg_preprocessing(image, 224, 224)
    return image

In [ ]:
num_calibration_batches = 2

# prepare calibration dataset
dataset = tf.data.Dataset.from_tensor_slices(calibration_files)
dataset = dataset.map(map_func=parse_file, num_parallel_calls=20)
dataset = dataset.batch(batch_size=BATCH_SIZE)
dataset = dataset.repeat(None)
calibration_dataset = dataset.take(num_calibration_batches)

def calibration_input_fn():
    for x in calibration_dataset:
        yield (x, )

Next, we proceed with the two-stage process of creating and calibrating the TF-TRT INT8 model.

Convert and calibrate the TF-TRT INT8 inference engine


In [ ]:
# set a directory to write the saved model
INT8_SAVED_MODEL_DIR =  SAVED_MODEL_DIR + "_TFTRT_INT8/2"
!rm -rf $INT8_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.INT8)

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert(calibration_input_fn=calibration_input_fn)

converter.save(INT8_SAVED_MODEL_DIR)

As before, we can benchmark the speed and accuracy of the resulting model.


In [ ]:
benchmark_saved_model(INT8_SAVED_MODEL_DIR)

Conclusion

In this notebook, we have demonstrated the process of creating TF-TRT inference models from an original TF FP32 saved model. In every case, we have also verified the accuracy and speed of the resulting model.