In [ ]:
import tensorflow as tf
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os, sys, cv2
import argparse
import os.path as osp
import glob
from easydict import EasyDict as edict


from lib.networks.factory import get_network
from lib.fast_rcnn.config import cfg
from lib.fast_rcnn.test import im_detect
from lib.fast_rcnn.nms_wrapper import nms_wrapper
from lib.utils.timer import Timer

# Pascal VOC category names used to label detections. Index 0 is the
# background class, followed by the 20 object categories in order.
CLASSES = ('__background__',) + tuple(
    'aeroplane bicycle bird boat '
    'bottle bus car cat chair '
    'cow diningtable dog horse '
    'motorbike person pottedplant '
    'sheep sofa train tvmonitor'.split()
)


# Alternative class list (kept for reference):
# CLASSES = ('__background__','person','bike','motorbike','car','bus')

def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes whose score passes ``thresh`` onto ``ax``.

    Args:
        im: The image being annotated. Not used here; kept so existing
            callers keep working.
        class_name: Label text drawn next to each box and used in the title.
        dets: 2-D array of detections. Columns 0-3 are read as
            ``x1, y1, x2, y2`` and the last column as the score. ``None`` or
            an empty array draws nothing.
        ax: Matplotlib axes to draw onto.
        thresh: Minimum score (inclusive) for a detection to be drawn.

    Returns:
        None. When at least one box is drawn, the axes title is overwritten,
        so the title reflects the most recent call.
    """
    if dets is None or len(dets) == 0:
        return

    inds = np.where(dets[:, -1] >= thresh)[0]
    if inds.size == 0:
        return

    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    # Operate on the axes/figure we were handed, not pyplot's "current"
    # axes/figure, which can belong to a different plot.
    ax.axis('off')
    ax.figure.tight_layout()
    ax.figure.canvas.draw_idle()


def demo(sess, net, image_name, conf_thresh=0.7):
    """Run the detector on one image file and plot its detections.

    Opens a new matplotlib figure showing the image with the boxes that
    ``nms_wrapper`` returns drawn on top.

    Args:
        sess: TensorFlow session holding the restored network weights.
        net: Network object passed through to ``im_detect``.
        image_name: Path of the image to load with ``cv2.imread``.
        conf_thresh: Score threshold. It is passed as the ``threshold``
            argument of ``nms_wrapper`` and to ``vis_detections``. The default
            matches the previously hard-coded 0.7.

    Raises:
        IOError: If ``image_name`` cannot be read as an image.
    """
    # cv2.imread signals failure by returning None rather than raising.
    im = cv2.imread(image_name)
    if im is None:
        raise IOError('Could not read image: {}'.format(image_name))

    # Detect all object classes and regress object bounds.
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))

    # OpenCV loads images as BGR; reorder channels to RGB for matplotlib.
    im = im[:, :, (2, 1, 0)]
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')

    # NOTE(review): r['class'] comes from nms_wrapper. Verify that it yields
    # the names in CLASSES rather than generic class ids.
    res = nms_wrapper(scores, boxes, threshold=conf_thresh)
    for r in res:
        dets = r['dets']
        if dets is None:
            continue
        vis_detections(im, r['class'], dets, ax, thresh=conf_thresh)


if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    # Arguments are hard-coded instead of parsed from the command line. This
    # file appears to be a notebook export (note the %matplotlib magic).
    args = edict()
    args.gpu_id = 0
    args.demo_net = "Resnet50_test"
    args.model = "./output/faster_rcnn_end2end_resnet_voc/voc_2007_trainval"

    if not args.model.strip() or not os.path.exists(args.model):
        # __file__ is undefined inside a notebook kernel, so report the
        # working directory that the relative model path is resolved against.
        print('current path is ' + os.getcwd())
        raise IOError('Error: Model not found.\n')

    # Load network.
    device_name = '/gpu:{:d}'.format(args.gpu_id)
    print(device_name)
    with tf.device(device_name):
        net = get_network(args.demo_net)
    saver = tf.train.Saver()

    # Init session.
    c = tf.ConfigProto(allow_soft_placement=True)
    c.gpu_options.visible_device_list = str(args.gpu_id)
    sess = tf.Session(config=c)

    # Load model. No initializer is run here, so without a checkpoint the
    # variables would stay uninitialized and detection would fail later with
    # a much less obvious error.
    print('Loading network {:s}... '.format(args.demo_net), end=' ')
    ckpt = tf.train.latest_checkpoint(args.model)
    if not ckpt:
        raise IOError('Error: no checkpoint found in {}'.format(args.model))
    # Per the original author, global_step is restored as well.
    saver.restore(sess, ckpt)
    print('restore from the checkpoint {0}'.format(ckpt))
    print(' done.')

    # Warm up on a dummy image, presumably so the timed detections below do
    # not include one-off startup cost.
    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for _ in range(2):
        im_detect(sess, net, im)

    # Sort so images are processed in a deterministic order.
    demo_dir = os.path.join(cfg.DATA_DIR, 'demo')
    im_names = sorted(glob.glob(os.path.join(demo_dir, '*.png')) +
                      glob.glob(os.path.join(demo_dir, '*.jpg')))
    if not im_names:
        print('No .png/.jpg images found in {}'.format(demo_dir))
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for {:s}'.format(im_name))
        demo(sess, net, im_name)

    plt.show()