In [ ]:
import sys
import threading
from grpc.beta import implementations
import numpy
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data sets from the "MNIST_data/" directory with one-hot
# encoded labels. do_inference reads its test images from `mnist.test`.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
In [ ]:
# "host:port" address of the PredictionService; passed to do_inference, which
# splits it on ':' to open the gRPC channel.
server = "tf-service:9000"
class _ResultCounter(object):
    """Counter for the prediction results.

    Keeps track of how many requests have completed, how many of them were
    counted as errors, and how many are currently in flight. Every field is
    read and written while holding a single `threading.Condition`, which is
    also used to block callers in `throttle` and `get_error_rate`.

    Args:
      num_tests: Number of completed requests `get_error_rate` waits for; it
        is also the denominator of the reported error rate.
      concurrency: Number of in-flight requests at which `throttle` blocks.
    """

    def __init__(self, num_tests, concurrency):
        # Store the caller-supplied limits rather than fixed constants so the
        # counter agrees with the loop that drives it.
        self._num_tests = num_tests
        self._concurrency = concurrency
        self._error = 0
        self._done = 0
        self._active = 0
        self._condition = threading.Condition()
        # Not read by any method of this class; kept for compatibility.
        self._workdir = "/tmp"

    def inc_error(self):
        """Records one failed or misclassified request."""
        with self._condition:
            self._error += 1

    def inc_done(self):
        """Records one completed request and wakes a waiting thread."""
        with self._condition:
            self._done += 1
            self._condition.notify()

    def dec_active(self):
        """Marks one in-flight request as finished and wakes a waiting thread."""
        with self._condition:
            self._active -= 1
            self._condition.notify()

    def get_error_rate(self):
        """Blocks until `num_tests` requests are done, then returns the rate.

        Returns:
          The number of recorded errors divided by `num_tests`, as a float.
        """
        with self._condition:
            while self._done != self._num_tests:
                self._condition.wait()
            return self._error / float(self._num_tests)

    def throttle(self):
        """Blocks while the in-flight count equals `concurrency`, then claims a slot."""
        with self._condition:
            while self._active == self._concurrency:
                self._condition.wait()
            self._active += 1
def _create_rpc_callback(label, result_counter):
    """Builds the completion callback for a single prediction RPC.

    Args:
      label: The correct label for the predicted example; its `argmax()` is
        compared against the predicted class.
      result_counter: Counter that receives the outcome of the request.

    Returns:
      A one-argument function suitable for `add_done_callback`.
    """

    def _callback(result_future):
        """Updates the counter from a finished RPC.

        Args:
          result_future: Result future of the RPC.
        """
        failure = result_future.exception()
        if not failure:
            # Progress marker: one dot per successful response.
            sys.stdout.write('.')
            sys.stdout.flush()
            scores = result_future.result().outputs['scores'].float_val
            predicted_class = numpy.argmax(numpy.array(scores))
            misclassified = label.argmax() != predicted_class
            if misclassified:
                result_counter.inc_error()
        else:
            result_counter.inc_error()
            print(failure)
        # Runs on both paths so the request is always counted as finished
        # and its concurrency slot is released.
        result_counter.inc_done()
        result_counter.dec_active()

    return _callback
def do_inference(hostport, work_dir, concurrency, num_tests):
    """Sends MNIST test images to a PredictionService and measures errors.

    Args:
      hostport: Address of the PredictionService in 'host:port' form.
      work_dir: Not used by this function; the test images come from the
        module-level `mnist` data set.
      concurrency: Maximum number of concurrent requests.
      num_tests: Number of test images to send.

    Returns:
      The classification error rate reported by the result counter.

    Raises:
      ValueError: If `hostport` does not split into exactly a host and a
        port on ':' or the port is not an integer.
    """
    test_data_set = mnist.test
    address, port_text = hostport.split(':')
    rpc_channel = implementations.insecure_channel(address, int(port_text))
    service_stub = prediction_service_pb2.beta_create_PredictionService_stub(
        rpc_channel)
    counter = _ResultCounter(num_tests, concurrency)

    for _unused in range(num_tests):
        predict_request = predict_pb2.PredictRequest()
        model_spec = predict_request.model_spec
        model_spec.name = 'default'
        model_spec.signature_name = 'predict_images'

        # One example per request; sent as a single-row tensor.
        images, labels = test_data_set.next_batch(1)
        image_tensor = tf.contrib.util.make_tensor_proto(
            images[0], shape=[1, images[0].size])
        predict_request.inputs['images'].CopyFrom(image_tensor)

        # Wait for a free slot before issuing the asynchronous call.
        counter.throttle()
        pending = service_stub.Predict.future(predict_request, 5.0)  # 5 seconds
        on_done = _create_rpc_callback(labels[0], counter)
        pending.add_done_callback(on_done)

    return counter.get_error_rate()
# Run 100 test images against the configured server, one request at a time,
# and report the error rate as a percentage.
error_rate = do_inference(
    hostport=server, work_dir="/tmp", concurrency=1, num_tests=100)
error_percent = error_rate * 100
print('\nInference error rate: %s%%' % (error_percent))
In [ ]: