In this notebook, we demonstrate how to create a TF-TRT optimized model from a TensorFlow SavedModel. The notebook is designed to run with TensorFlow 2.x, which is included in the NVIDIA NGC TensorFlow containers from version nvcr.io/nvidia/tensorflow:19.12-tf2-py3 onward; these containers can be downloaded from the NGC website.
We will run this demonstration with a saved ResNet-v1-50 model, to be downloaded and stored at /path/to/saved_model.
The INT8 calibration process requires access to a small but representative sample of real training or validation data.
We will use the ImageNet dataset stored in TFRecord format. Google provides an excellent all-in-one script for downloading and preparing the ImageNet dataset.
To run this notebook, start the NGC TF container, providing the correct paths to the ImageNet validation data /path/to/image_net and to the folder /path/to/saved_model containing the TF saved model:
nvidia-docker run --rm -it -p 8888:8888 -v /path/to/image_net:/data -v /path/to/saved_model:/saved_model --name TFTRT nvcr.io/nvidia/tensorflow:19.12-tf2-py3
Within the container, we then start the Jupyter notebook server with:
jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root
Finally, connect to the Jupyter web interface on your host at http://localhost:8888.
We first install some extra packages and external dependencies needed, for example, for preprocessing the ImageNet data.
In [ ]:
%%bash
pushd /workspace/nvidia-examples/tensorrt/tftrt/examples/object_detection/
bash install_dependencies.sh;
popd
In [ ]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import time
import logging
import numpy as np
import tensorflow as tf
print("TensorFlow version: ", tf.__version__)
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants
logging.getLogger("tensorflow").setLevel(logging.ERROR)
# check TensorRT version
print("TensorRT version: ")
!dpkg -l | grep nvinfer
In [ ]:
def get_files(data_dir, filename_pattern):
    """Return all files in data_dir matching filename_pattern."""
    if data_dir is None:
        return []
    files = tf.io.gfile.glob(os.path.join(data_dir, filename_pattern))
    if not files:
        raise ValueError('Cannot find any files in {} with '
                         'pattern "{}"'.format(data_dir, filename_pattern))
    return files
In [ ]:
VALIDATION_DATA_DIR = "/data"
validation_files = get_files(VALIDATION_DATA_DIR, 'validation*')
print('There are %d validation files. \n%s\n%s\n...'%(len(validation_files), validation_files[0], validation_files[-1]))
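As a quick sanity check (a minimal sketch, not part of the original workflow), we can count the records in the first validation shard to confirm that the TFRecord files are readable:
In [ ]:
# Count the records in the first validation shard -- a cheap readability check.
num_records = sum(1 for _ in tf.data.TFRecordDataset(validation_files[0]))
print('%s contains %d records' % (validation_files[0], num_records))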
If not already downloaded, we download and work with a ResNet-50 v1 SavedModel from https://github.com/tensorflow/models/tree/master/official/resnet
In [ ]:
%%bash
FILE=/saved_model/resnet_v1_fp32_savedmodel_NHWC.tar.gz
if [ -f $FILE ]; then
    echo "The file '$FILE' exists."
else
    echo "The file '$FILE' was not found. Downloading..."
    wget -P /saved_model/ http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp32_savedmodel_NHWC.tar.gz
fi
tar -xzvf /saved_model/resnet_v1_fp32_savedmodel_NHWC.tar.gz -C /saved_model
In [ ]:
def deserialize_image_record(record):
    feature_map = {
        'image/encoded':          tf.io.FixedLenFeature([], tf.string, ''),
        'image/class/label':      tf.io.FixedLenFeature([1], tf.int64, -1),
        'image/class/text':       tf.io.FixedLenFeature([], tf.string, ''),
        'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32)
    }
    with tf.name_scope('deserialize_image_record'):
        obj = tf.io.parse_single_example(record, feature_map)
        imgdata = obj['image/encoded']
        label = tf.cast(obj['image/class/label'], tf.int32)
        bbox = tf.stack([obj['image/object/bbox/%s' % x].values
                         for x in ['ymin', 'xmin', 'ymax', 'xmax']])
        bbox = tf.transpose(tf.expand_dims(bbox, 0), [0, 2, 1])
        text = obj['image/class/text']
        return imgdata, label, bbox, text
In [ ]:
from preprocessing import vgg_preprocess as vgg_preprocessing

def preprocess(record):
    # Parse TFRecord
    imgdata, label, bbox, text = deserialize_image_record(record)
    # label -= 1  # Change to 0-based if not using background class
    try:
        image = tf.image.decode_jpeg(imgdata, channels=3,
                                     fancy_upscaling=False,
                                     dct_method='INTEGER_FAST')
    except:
        image = tf.image.decode_png(imgdata, channels=3)
    image = vgg_preprocessing(image, 224, 224)
    return image, label
In [ ]:
# Define some global variables
BATCH_SIZE = 64
In [ ]:
SAVED_MODEL_DIR = "/saved_model/resnet_v1_fp32_savedmodel_NHWC/1538686669/"
We employ saved_model_cli to inspect the inputs and outputs of the model.
In [ ]:
!saved_model_cli show --all --dir $SAVED_MODEL_DIR
This gives us the names of the input and output tensors: input_tensor:0 and softmax_tensor:0, respectively. Also note that the number of output classes here is 1001 rather than the 1000 ImageNet classes, because the network was trained with an extra background class.
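Alternatively (a minimal sketch, assuming the TF 2.x SavedModel API used later in this notebook), the serving signature can be inspected directly from Python:
In [ ]:
# Hedged alternative to saved_model_cli: load the model and print its serving
# signature. We expect a (None, 224, 224, 3) float input and 'classes' /
# 'probabilities' outputs with 1001 classes.
loaded = tf.saved_model.load(SAVED_MODEL_DIR, tags=[tag_constants.SERVING])
infer = loaded.signatures['serving_default']
print(infer.structured_input_signature)
print(infer.structured_outputs)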
In [ ]:
INPUT_TENSOR = 'input_tensor:0'
OUTPUT_TENSOR = 'softmax_tensor:0'
Next, we define a function that loads a saved model and measures its speed and accuracy on the validation data.
In [ ]:
def benchmark_saved_model(SAVED_MODEL_DIR, BATCH_SIZE=64):
    # load saved model
    saved_model_loaded = tf.saved_model.load(SAVED_MODEL_DIR, tags=[tag_constants.SERVING])
    signature_keys = list(saved_model_loaded.signatures.keys())
    print(signature_keys)
    infer = saved_model_loaded.signatures['serving_default']
    print(infer.structured_outputs)

    # prepare dataset iterator
    dataset = tf.data.TFRecordDataset(validation_files)
    dataset = dataset.map(map_func=preprocess, num_parallel_calls=20)
    dataset = dataset.batch(batch_size=BATCH_SIZE, drop_remainder=True)

    print('Warming up for 50 batches...')
    cnt = 0
    for x, y in dataset:
        labeling = infer(x)
        cnt += 1
        if cnt == 50:
            break

    print('Benchmarking inference engine...')
    num_hits = 0
    num_predict = 0
    start_time = time.time()
    for x, y in dataset:
        labeling = infer(x)
        preds = labeling['classes'].numpy()
        # labels come out of the parser with shape (batch, 1); squeeze them so
        # the comparison with the (batch,) predictions does not broadcast
        num_hits += np.sum(preds == np.squeeze(y.numpy()))
        num_predict += preds.shape[0]
    print('Accuracy: %.2f%%' % (100 * num_hits / num_predict))
    print('Inference speed: %.2f samples/s' % (num_predict / (time.time() - start_time)))
In [ ]:
benchmark_saved_model(SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)
In [ ]:
FP32_SAVED_MODEL_DIR = SAVED_MODEL_DIR + "_TFTRT_FP32/1"
!rm -rf $FP32_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.FP32)
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert()
converter.save(FP32_SAVED_MODEL_DIR)

benchmark_saved_model(FP32_SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)
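Beyond the precision mode, a few other fields of TrtConversionParams are commonly tuned. A hedged sketch (the values below are illustrative assumptions to adapt, not recommendations from the original notebook):
In [ ]:
# Additional conversion knobs -- the values here are assumptions to tune.
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.FP32,
    max_workspace_size_bytes=1 << 30,  # scratch memory TensorRT may use (1 GiB)
    minimum_segment_size=3)            # only convert subgraphs of >= 3 nodes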
In [ ]:
FP16_SAVED_MODEL_DIR = SAVED_MODEL_DIR + "_TFTRT_FP16/1"
!rm -rf $FP16_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.FP16)
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert()
converter.save(FP16_SAVED_MODEL_DIR)
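By default, the TensorRT engines are built lazily on the first inference call. In newer TensorFlow 2.x releases, TrtGraphConverterV2 also exposes a build() method to pre-build the engines between convert() and save(); whether it is available in this container version is an assumption to verify:
In [ ]:
# Hedged sketch: pre-build TensorRT engines so the first batch does not pay the
# engine-construction cost. Requires a TF version where build() is available.
def engine_build_input_fn():
    # one representative batch with the deployment input shape
    yield (tf.zeros((BATCH_SIZE, 224, 224, 3), dtype=tf.float32),)

converter.build(input_fn=engine_build_input_fn)  # call between convert() and save()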
In [ ]:
benchmark_saved_model(FP16_SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)
Creating a TF-TRT INT8 inference model requires two steps:
Step 1: Prepare a calibration dataset
Step 2: Convert and calibrate the TF-TRT INT8 inference engine
Creating a TF-TRT INT8 model requires a small calibration dataset. This dataset should ideally be representative of the test data seen in production; it is used to build a histogram of the values flowing through each layer of the network, from which effective 8-bit quantization ranges are derived.
In [ ]:
num_calibration_batches = 2

# prepare calibration dataset
dataset = tf.data.TFRecordDataset(validation_files)
dataset = dataset.map(map_func=preprocess, num_parallel_calls=20)
dataset = dataset.batch(batch_size=BATCH_SIZE, drop_remainder=True)
calibration_dataset = dataset.take(num_calibration_batches)

def calibration_input_fn():
    for x, y in calibration_dataset:
        yield (x, )
In [ ]:
# set a directory to write the saved model
INT8_SAVED_MODEL_DIR = SAVED_MODEL_DIR + "_TFTRT_INT8/1"
!rm -rf $INT8_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.INT8)
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert(calibration_input_fn=calibration_input_fn)
converter.save(INT8_SAVED_MODEL_DIR)
In [ ]:
benchmark_saved_model(INT8_SAVED_MODEL_DIR, BATCH_SIZE=BATCH_SIZE)
In [ ]:
!saved_model_cli show --all --dir $INT8_SAVED_MODEL_DIR
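In the converted SavedModel, TensorRT-compatible subgraphs have been replaced by TRTEngineOp nodes. A hedged way to confirm this from Python (a sketch using the same TF 2.x loading API as above):
In [ ]:
# Count TRTEngineOp nodes -- each one wraps a TensorRT-optimized subgraph.
# Such nodes may live in the top-level graph or in its function library.
loaded = tf.saved_model.load(INT8_SAVED_MODEL_DIR, tags=[tag_constants.SERVING])
graph_def = loaded.signatures['serving_default'].graph.as_graph_def()
n_trt = sum(1 for n in graph_def.node if n.op == 'TRTEngineOp')
n_trt += sum(1 for f in graph_def.library.function
             for n in f.node_def if n.op == 'TRTEngineOp')
print('Found %d TRTEngineOp node(s)' % n_trt)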
As an alternative to taking data in TFRecord format, in this section we demonstrate how to calibrate a TF-TRT INT8 model from a directory of raw JPEG images. We assume that the raw images have been mounted at /data/Calibration_data.
As a rule of thumb, the calibration data should be a small but representative set of images, similar to what is expected in deployment. Empirically, for common network architectures trained on ImageNet data, a calibration set of 500-1000 images provides good accuracy. A good strategy for a dataset such as ImageNet is therefore to choose one sample from each class, as sketched below.
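A minimal sketch of this strategy, assuming the raw images are organized into one subdirectory per class (/data/Calibration_data/<class_name>/<image>.jpg; the layout and the helper are illustrative, not part of the original workflow):
In [ ]:
# Hedged sketch: pick the first image of every class subdirectory.
# Assumes a /data/Calibration_data/<class_name>/... layout -- adjust as needed.
def one_image_per_class(data_directory):
    picks = []
    for class_dir in sorted(os.listdir(data_directory)):
        class_path = os.path.join(data_directory, class_dir)
        if not os.path.isdir(class_path):
            continue
        images = sorted(os.listdir(class_path))
        if images:
            picks.append(os.path.join(class_path, images[0]))
    return picks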
In [ ]:
data_directory = "/data/Calibration_data"
calibration_files = [os.path.join(path, name) for path, _, files in os.walk(data_directory) for name in files]
print('There are %d calibration files. \n%s\n%s\n...'%(len(calibration_files), calibration_files[0], calibration_files[-1]))
We define a helper function to read and preprocess an image from a JPEG file.
In [ ]:
def parse_file(filepath):
    image = tf.io.read_file(filepath)
    image = tf.image.decode_jpeg(image, channels=3)
    image = vgg_preprocessing(image, 224, 224)
    return image
In [ ]:
num_calibration_batches = 2

# prepare calibration dataset
dataset = tf.data.Dataset.from_tensor_slices(calibration_files)
dataset = dataset.map(map_func=parse_file, num_parallel_calls=20)
dataset = dataset.batch(batch_size=BATCH_SIZE)
dataset = dataset.repeat(None)
calibration_dataset = dataset.take(num_calibration_batches)

def calibration_input_fn():
    for x in calibration_dataset:
        yield (x, )
In [ ]:
# set a directory to write the saved model
INT8_SAVED_MODEL_DIR = SAVED_MODEL_DIR + "_TFTRT_INT8/2"
!rm -rf $INT8_SAVED_MODEL_DIR

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.INT8)
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params)
converter.convert(calibration_input_fn=calibration_input_fn)
converter.save(INT8_SAVED_MODEL_DIR)
As before, we can benchmark the speed and accuracy of the resulting model.
In [ ]:
benchmark_saved_model(INT8_SAVED_MODEL_DIR)