In [1]:
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
# Load the FDDB box annotations used by the sampling/plotting cells below.
data = pd.read_csv('FDDB2XYWH.csv')
# 'Unnamed: 0' is presumably a pandas index that was written by `to_csv`
# without `index=False`. Drop it by name: the positional axis form
# `drop('Unnamed: 0', 1)` is deprecated and rejected by pandas >= 2.0.
# `errors='ignore'` keeps this cell working if the CSV is regenerated
# without that column.
data = data.drop(columns='Unnamed: 0', errors='ignore')
data.head()
In [3]:
import batch_generate

# Fixed seed so "Restart & Run All" shows the same example every time;
# change it to browse other images.
SAMPLE_SEED = 0
i_line = np.random.RandomState(SAMPLE_SEED).randint(len(data))
name_str, img, bb_boxes = batch_generate.get_img_by_name(
    data, i_line, size=(960, 640), dataset='FDDB')

# Keep a genuinely independent copy of the image for the final plot.
# `copy_img = img` only bound a second name to the same array, so any later
# in-place edit of `img` (e.g. mean subtraction) would also alter the "copy".
copy_img = img.copy()
print(bb_boxes)

# Convert the annotation boxes (presumably x, y, w, h given the CSV name) to
# the corner form the plot below expects: columns 0-3 are treated as
# (x_min, y_min, x_max, y_max).
gta = batch_generate.bbox_transform(bb_boxes)
print(gta)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(img)
ax.set_title('Ground-truth boxes: {}'.format(name_str))
# Iterate rows directly so an empty box array simply draws nothing.
for gt_box in gta:
    x_min, y_min, x_max, y_max = gt_box[:4]
    ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               fill=False, edgecolor='r', linewidth=1))
In [4]:
# Inference: rebuild the detection graph, restore the trained weights and run
# one forward pass on the image sampled above.
import config
from netarch import ResNet50

CHECKPOINT_PATH = './tf_detection/model.ckpt-2000'
# Per-channel values subtracted from the input. These match the widely used
# Caffe/VGG ImageNet means, which are in B, G, R order.
# NOTE(review): confirm `batch_generate.get_img_by_name` returns BGR images;
# if it returns RGB the channel order here is wrong.
img_channel_mean = [103.939, 116.779, 123.68]

# Preprocess into a NEW array rather than overwriting `img`. The previous
# in-place version replaced the global `img` with the mean-subtracted float
# image, so re-running this cell subtracted the mean a second time.
# Broadcasting over the last axis subtracts one mean per channel.
img_input = img.astype(np.float32) - np.asarray(img_channel_mean, dtype=np.float32)

with tf.Graph().as_default():
    mc = config.model_parameters()
    # Weights are restored from CHECKPOINT_PATH below, so skip the
    # pretrained-model loading path.
    mc.LOAD_PRETRAINED_MODEL = False
    model = ResNet50(mc, '0')  # NOTE(review): meaning of '0' not visible here (GPU id?) — confirm
    saver = tf.train.Saver(model.model_params)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        saver.restore(sess, CHECKPOINT_PATH)
        det_probs, det_boxes = sess.run(
            [model.det_probs, model.det_boxes],
            feed_dict={
                model.image_input: img_input[np.newaxis],  # batch of one image
                model.keep_prob: 1.0,  # presumably disables dropout at inference
            })
In [5]:
print(det_probs.shape, det_boxes.shape)
# Take the single image in the batch and flatten to one row per box.
# Column 1 of the 2-way class scores is used as the detection score —
# presumably the face/foreground class; confirm against the model head.
box_probs = np.reshape(det_probs[0], [-1, 2])[:, 1]
# Let NumPy infer the box count, matching the line above. It was hard-coded
# to 21600, which only holds for one input size / anchor layout.
box_delta = np.reshape(det_boxes[0], [-1, 4])
assert len(box_probs) == len(box_delta), 'expected one score per box delta'
print(box_probs.shape, box_delta.shape)
In [8]:
import utils

# Turn the predicted deltas into boxes relative to the model's anchors
# (presumably corner form, per the `_xyxy` name), then apply non-maximum
# suppression to keep the strongest non-overlapping detections.
anchor_box = mc.ANCHOR_BOX
pred_box_xyxy = utils.bbox_delta_convert_inv(anchor_box, box_delta)
box_nms, probs_nms = utils.non_max_suppression_fast(
    pred_box_xyxy,
    box_probs,
    5,                   # NOTE(review): meaning of this positional arg is not visible here — check utils
    overlap_thresh=0.5,  # presumably the IoU above which a lower-scoring box is dropped
)
In [9]:
# Overlay the post-NMS detections on the unmodified image from the sampling
# cell. Columns 0-3 of each box are treated as (x_min, y_min, x_max, y_max).
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(copy_img)
ax.set_title('Detections after NMS ({} boxes)'.format(len(box_nms)))
# Iterate rows directly so an empty NMS result simply draws nothing.
for det_box in box_nms:
    x_min, y_min, x_max, y_max = det_box[:4]
    ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               fill=False, edgecolor='r', linewidth=1))
In [ ]: