In [ ]:
#!/usr/bin/env python
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""Test a Fast R-CNN network on an image database."""
%matplotlib inline
import sys,os
from lib.fast_rcnn.test import test_net, load_test_net
from lib.fast_rcnn.config import cfg, cfg_from_file
from lib.datasets.factory import get_imdb
from lib.networks.factory import get_network
import argparse
import pprint
import time
import tensorflow as tf
from easydict import EasyDict as edict
if __name__ == '__main__':
    # Settings are hard-coded for interactive (notebook) use instead of being
    # parsed from the command line with parse_args().
    args = edict()
    args.cfg_file = "./experiments/cfgs/faster_rcnn_end2end_resnet.yml"
    # Directory holding the trained TensorFlow checkpoints; the newest one is
    # selected with tf.train.latest_checkpoint().
    args.model = "./output/faster_rcnn_end2end_resnet_voc/voc_2007_trainval"
    # Physical GPU index to run on.
    args.gpu_id = 0
    # When True, block until a checkpoint appears in args.model, polling every
    # args.wait_interval seconds.
    args.wait = True
    args.wait_interval = 10
    args.imdb_name = "voc_2007_test"
    # Flag forwarded to imdb.competition_mode(). This used to be the string
    # "store_true" (the argparse *action* name, copied by mistake), which is
    # truthy and silently switched competition mode on. False is the value
    # `action='store_true'` produces when the flag is not given; set it to True
    # explicitly to enable competition mode.
    args.comp_mode = False
    args.network_name = "Resnet50_test"

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    print('Using config:')
    pprint.pprint(cfg)

    # Wait for an actual checkpoint rather than only the directory: the
    # directory may exist before the first save, and restoring from the None
    # that latest_checkpoint() returns in that case fails with an unhelpful
    # error.
    ckpt_path = tf.train.latest_checkpoint(args.model)
    while ckpt_path is None and args.wait:
        print('Waiting for a checkpoint in {} ...'.format(args.model))
        time.sleep(args.wait_interval)
        ckpt_path = tf.train.latest_checkpoint(args.model)
    if ckpt_path is None:
        raise IOError('No checkpoint found in {}'.format(args.model))

    # Name passed to test_net() to identify these weights. normpath() strips a
    # trailing separator so basename() does not return ''.
    weights_filename = os.path.splitext(
        os.path.basename(os.path.normpath(args.model)))[0]

    imdb = get_imdb(args.imdb_name)
    imdb.competition_mode(args.comp_mode)

    # The session config below sets visible_device_list to args.gpu_id, which
    # makes that physical GPU the only visible one and TensorFlow names it
    # '/gpu:0'. Pinning the graph to '/gpu:{gpu_id}' would therefore name a
    # device that does not exist whenever gpu_id != 0.
    device_name = '/gpu:0'
    print(device_name)
    with tf.device(device_name):
        network = get_network(args.network_name)
    print('Use network `{:s}` in testing'.format(args.network_name))

    # Physical GPU index kept in the global config, presumably for library
    # code that reads cfg.GPU_ID directly.
    cfg.GPU_ID = args.gpu_id

    # Alternative device selection (see issue #152): set
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" and
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) before creating
    # the session instead of using visible_device_list.
    saver = tf.train.Saver()
    c = tf.ConfigProto(allow_soft_placement=True)
    c.gpu_options.visible_device_list = str(args.gpu_id)
    # The session is not closed in this cell; call sess.close() when finished
    # to release GPU memory.
    sess = tf.Session(config=c)
    print('Loading model weights from {:s}'.format(ckpt_path))
    saver.restore(sess, ckpt_path)
    test_net(sess, network, imdb, weights_filename, vis=True, thresh=0.7)
    # load_test_net enables you to work on previously generated test results:
    # load_test_net(sess, network, imdb, weights_filename)