# In[ ]:
#!/usr/bin/env python

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Test a Fast R-CNN network on an image database."""
%matplotlib inline
import sys,os

from lib.fast_rcnn.test import test_net, load_test_net
from lib.fast_rcnn.config import cfg, cfg_from_file
from lib.datasets.factory import get_imdb
from lib.networks.factory import get_network
import argparse
import pprint
import time
import tensorflow as tf
from easydict import EasyDict as edict

if __name__ == '__main__':
    # Hard-coded stand-in for the command-line arguments (parse_args() is not
    # defined in this file), so the script can be run from a notebook cell.
    args = edict()
    args.cfg_file = "./experiments/cfgs/faster_rcnn_end2end_resnet.yml"
    # Directory searched for the newest TensorFlow checkpoint (see
    # tf.train.latest_checkpoint below).
    args.model = "./output/faster_rcnn_end2end_resnet_voc/voc_2007_trainval"
    args.gpu_id = 0
    # When True, poll until args.model exists instead of proceeding at once.
    args.wait = True
    args.imdb_name = "voc_2007_test"
    # This used to be the string "store_true" (an argparse `action` value
    # pasted in as data). Any non-empty string is truthy, so competition mode
    # was always enabled; keep that behaviour but state it as a real boolean.
    args.comp_mode = True
    args.network_name = "Resnet50_test"
    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    print('Using config:')
    pprint.pprint(cfg)

    # Wait for the model directory to appear (e.g. while training is still
    # writing it). The previous 1000 s interval meant ~17 minutes per check.
    poll_seconds = 10
    while args.wait and not os.path.exists(args.model):
        print('Waiting for {} to exist...'.format(args.model))
        time.sleep(poll_seconds)

    # Last path component of the model location, passed to test_net.
    # NOTE(review): presumably used to name the results directory — confirm
    # in lib.fast_rcnn.test.
    weights_filename = os.path.splitext(os.path.basename(args.model))[0]

    imdb = get_imdb(args.imdb_name)
    imdb.competition_mode(args.comp_mode)

    device_name = '/gpu:{:d}'.format(args.gpu_id)
    print(device_name)
    with tf.device(device_name):
        network = get_network(args.network_name)
    print('Use network `{:s}` in training'.format(args.network_name))

    cfg.GPU_ID = args.gpu_id

    # Created after the network graph is built so that the Saver (no var_list)
    # covers every variable the network defined.
    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True)
    # NOTE(review): visible_device_list exposes only the chosen GPU, which TF
    # then numbers as /gpu:0; for gpu_id != 0 the '/gpu:{gpu_id}' placement
    # above may not exist and soft placement will relocate ops — confirm.
    config.gpu_options.visible_device_list = str(args.gpu_id)

    # latest_checkpoint returns None when no checkpoint is found; fail here
    # with a clear message instead of deep inside saver.restore.
    checkpoint = tf.train.latest_checkpoint(args.model)
    if checkpoint is None:
        raise IOError('No checkpoint found in {:s}'.format(args.model))

    # The session is intentionally left open so later notebook cells can
    # reuse `sess`.
    sess = tf.Session(config=config)
    print('Loading model weights from {:s}'.format(args.model))
    saver.restore(sess, checkpoint)

    test_net(sess, network, imdb, weights_filename, vis=True, thresh=0.7)
    # To work on previously generated test results instead of running the
    # network again, call this in place of test_net:
    #     load_test_net(sess, network, imdb, weights_filename)