This script retrieves images from a photo collection using a sketch as the query.
In [2]:
import numpy as np
from pylab import *
%matplotlib inline
import os
import sys
First, we need to import Caffe. You'll need to have Caffe installed, as well as its Python interface (pycaffe).
In [3]:
#TODO: specify your caffe root folder here
caffe_root = "X:\caffe_siggraph/caffe-windows-master"
sys.path.insert(0, caffe_root+'/python')
import caffe
Now we can load up the network. You can change the paths to your own network here. Make sure to use the matching deploy prototxt files and to change the target layers to your own layer names.
In [4]:
#TODO: change to your own network and deploying file
PRETRAINED_FILE = '../models/triplet_googlenet/triplet_googlenet_finegrain_final.caffemodel'
sketch_model = '../models/triplet_googlenet/googlenet_sketchdeploy.prototxt'
image_model = '../models/triplet_googlenet/googlenet_imagedeploy.prototxt'
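If loading fails, the first thing to check is that these paths actually resolve. A quick optional check with plain os.path:
In [ ]:
# Optional sanity check: verify the weights and deploy files exist before loading them
for p in [PRETRAINED_FILE, sketch_model, image_model]:
    print p, 'found' if os.path.exists(p) else 'MISSING'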
In [ ]:
caffe.set_mode_gpu()
#caffe.set_mode_cpu()
sketch_net = caffe.Net(sketch_model, PRETRAINED_FILE, caffe.TEST)
img_net = caffe.Net(image_model, PRETRAINED_FILE, caffe.TEST)
sketch_net.blobs.keys()
In [6]:
#TODO: set the output layer names. You can use sketch_net.blobs.keys() to list all layers
output_layer_sketch = 'pool5/7x7_s1_s'
output_layer_image = 'pool5/7x7_s1_p'
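If you are unsure which layer names your own deploy files use, you can print every blob name together with its shape and pick the output layers from that list. The names above are just the ones used in our prototxt files.
In [ ]:
# Optional: list all blob names and their shapes to find the right output layer names
for name in sketch_net.blobs.keys():
    print name, np.shape(sketch_net.blobs[name].data)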
In [7]:
#set up the input transformer
transformer = caffe.io.Transformer({'data': np.shape(sketch_net.blobs['data'].data)})
transformer.set_mean('data', np.array([104, 117, 123]))  # subtract the BGR channel means
transformer.set_transpose('data', (2,0,1))               # HxWxC -> CxHxW
transformer.set_channel_swap('data', (2,1,0))            # RGB -> BGR
transformer.set_raw_scale('data', 255.0)                 # [0,1] -> [0,255]
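As a quick check, the transformer should produce inputs matching the network's data blob, whose shape you can print directly:
In [ ]:
# Optional: the expected input shape, e.g. (batch_size, 3, height, width)
print 'Expected input shape:', np.shape(sketch_net.blobs['data'].data)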
The following cells show how to use our network to do the retrieval. The easiest way to use the script is to put all the images you want to retrieve from in one folder and modify 'photo_paths' to point to that folder, then change 'sketch_path' to point to the sketch you want to use as the query.
In [43]:
#TODO: specify photo folder for the retrieval
photo_paths = 'C:\Users\Patsorn\Documents/notebook_backup/SBIR/retrieval/'
In [44]:
#load up images
img_list = os.listdir(photo_paths)
N = np.shape(img_list)[0]
print 'Retrieving from', N,'photos'
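Note that os.listdir returns everything in the folder. If your retrieval folder contains anything other than images, a small filtering step like the sketch below (assuming .jpg/.jpeg/.png photos) avoids errors later in the feature-extraction loop:
In [ ]:
# Optional: keep only image files (assumes .jpg/.jpeg/.png photos in the folder)
img_list = [f for f in img_list if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
N = len(img_list)
print 'Retrieving from', N, 'photos after filtering'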
In [45]:
#extract features for all images
feats = []
for i, path in enumerate(img_list):
    imgname = path.split('/')[-1]
    imgname = imgname.split('.jpg')[0]
    imgcat = path.split('/')[0]
    print '\r', str(i+1)+'/'+str(N)+' '+'Extracting '+path+'...',
    full_path = photo_paths + path
    img = transformer.preprocess('data', caffe.io.load_image(full_path.rstrip()))
    img_in = np.reshape([img], np.shape(sketch_net.blobs['data'].data))
    out_img = img_net.forward(data=img_in)
    out_img = np.copy(out_img[output_layer_image])
    feats.append(out_img)
    print 'done',
np.shape(feats)
feats = np.resize(feats, [np.shape(feats)[0], np.shape(feats)[2]]) #quick fix for the feature shape
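The np.resize call above is a quick fix to flatten the per-image features into an (N, D) matrix. An arguably clearer equivalent (same assumption: one feature vector per image) is a plain reshape:
In [ ]:
# Equivalent, more explicit flattening of the feature list into an (N, D) matrix
feats = np.reshape(feats, (len(feats), -1))
print np.shape(feats)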
In [46]:
#build the nearest-neighbor pool over the image features (brute-force cosine distance)
from sklearn.neighbors import NearestNeighbors
nbrs = NearestNeighbors(n_neighbors=np.size(feats, 0), algorithm='brute', metric='cosine').fit(feats)
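With brute-force cosine distance, a quick sanity check is to query the index with one of the gallery features itself: its own index should come back first with a distance close to zero.
In [ ]:
# Optional sanity check (assumes at least 3 photos): a gallery feature should retrieve itself at distance ~0
d, idx = nbrs.kneighbors(feats[0:1], n_neighbors=3)
print idx[0], d[0]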
In [47]:
#Load up sketch query
sketch_path = "X:\data_for_research\sketch_dataset\png/giraffe/7366.png"
sketch_in = (transformer.preprocess('data', caffe.io.load_image(sketch_path)))
sketch_in = np.reshape([sketch_in],np.shape(sketch_net.blobs['data'].data))
query = sketch_net.forward(data=sketch_in)
query=np.copy(query[output_layer_sketch])
#get nearest neighbors of the query (reshape to a single-row 2D array for kneighbors)
distances, indices = nbrs.kneighbors(np.reshape(query, [1, -1]))
#show query
f = plt.figure(0)
plt.imshow(plt.imread(sketch_path))
plt.axis('off')
#show results
for i in range(1,5,1):
    f = plt.figure(i)
    img = plt.imread(photo_paths+img_list[indices[0][i-1]])
    plt.imshow(img)
    plt.axis('off')
    plt.show(block=False)
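If you prefer a text listing over figures, the distances and indices computed above can be printed directly. This is just a small sketch using the variables already defined (it assumes at least five photos in the folder):
In [ ]:
# Optional: print the top-5 retrieved filenames with their cosine distances
for rank in range(5):
    print rank+1, img_list[indices[0][rank]], distances[0][rank]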