In [0]:
from IPython import display
!git clone https://github.com/tensorflow/tpu
display.clear_output()
In [42]:
from __future__ import print_function
checkpoint_name = 'mnasnet-a1' #@param
url = 'https://storage.googleapis.com/mnasnet/checkpoints/' + checkpoint_name + '.tar.gz'
print('Downloading from ', url)
!wget {url}
print('Unpacking')
!tar -xvf {checkpoint_name}.tar.gz
display.clear_output()
print('Successfully downloaded checkpoint from ', url,
'. It is available as', checkpoint_name)
In [43]:
!wget https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG -O panda.jpg
In [0]:
# Make the MnasNet model package and the shared `imagenet` helpers importable.
# Paths are resolved against the directory the repo was cloned into (the
# working directory; /content on Colab) instead of hard-coding /content, and
# entries already on sys.path are skipped so re-running the cell is harmless.
import os
import sys

_tpu_models_dir = os.path.abspath(os.path.join('tpu', 'models'))
for _module_dir in (os.path.join(_tpu_models_dir, 'official', 'mnasnet'),
                    os.path.join(_tpu_models_dir, 'common')):
    if _module_dir not in sys.path:
        sys.path.append(_module_dir)
In [45]:
from IPython import display
import pylab
# `import PIL` alone does not load the PIL.Image submodule; import it explicitly
# instead of relying on another library (pylab/matplotlib) having done so.
import PIL.Image
import numpy as np

filename = 'panda.jpg'
display.display(display.Image(filename))
# Force 3 channels (grayscale/RGBA/CMYK files would otherwise change the array
# shape) and resize to the 224x224 input used by the inference cell.
# np.float was removed in NumPy 1.24; it was an alias of Python float, i.e.
# float64, so np.float64 keeps the exact same dtype.
img = np.array(PIL.Image.open(filename).convert('RGB').resize((224, 224))).astype(np.float64)
In [46]:
import os
import tensorflow as tf

# Load the SavedModel unpacked by the download cell. `checkpoint_name` is taken
# from that cell's #@param form field; re-assigning it here would silently load
# 'mnasnet-a1' even after a different checkpoint was downloaded.
export_dir = os.path.join(checkpoint_name, 'saved_model')
# The export is a TF1 graph; the tf.compat.v1 APIs load it on TF >= 1.13 and TF 2.x.
serv_sess = tf.compat.v1.Session(graph=tf.Graph())
meta_graph_def = tf.compat.v1.saved_model.loader.load(
    serv_sess, [tf.compat.v1.saved_model.tag_constants.SERVING], export_dir)
In [0]:
# Show the exported serving signature, e.g. to check the tensor names that
# the inference cell feeds and fetches.
signature = 'serving_default'
print('Serving Signature: ', signature)
serving_sig_def = meta_graph_def.signature_def[signature]
print(serving_sig_def)
In [47]:
import imagenet

# Run the loaded graph on a batch holding the single preprocessed image.
top_class, probs = serv_sess.run(
    fetches=["ArgMax:0", "softmax_tensor:0"],
    feed_dict={"Placeholder:0": [img]})
print("Top class: ", top_class[0], " with Probability= ", probs[0][top_class[0]])

label_map = imagenet.create_readable_names_for_imagenet_labels()
# Indices of the five largest probabilities for the first (only) image,
# highest first.
top5_ids = np.argsort(probs[0])[::-1][:5]
for rank, label_id in enumerate(top5_ids, start=1):
    print("Top %d Prediction: %d, %s, probs=%f" % (rank, label_id, label_map[label_id], probs[0][label_id]))