In [15]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from dataset import MNISTDataset
from model import *

from scipy.spatial.distance import cdist
from matplotlib import gridspec


  File "cell_name", line 2
SyntaxError: import * only allowed at module level

In [7]:
dataset = MNISTDataset()
train_images = dataset.images_train[:20000]
test_images = dataset.images_test
len_test = len(test_images)
len_train = len(train_images)


===Loading MNIST Dataset===
Images train : (60000, 28, 28, 1)
Labels train : (60000, 1)
Images test  : (10000, 28, 28, 1)
Labels test  : (10000,)
Unique label : [0 1 2 3 4 5 6 7 8 9]

In [8]:
#helper function to plot image
def show_image(idxs, data):
    if type(idxs) != np.ndarray:
        idxs = np.array([idxs])
    fig = plt.figure()
    gs = gridspec.GridSpec(1,len(idxs))
    for i in range(len(idxs)):
        ax = fig.add_subplot(gs[0,i])
        ax.imshow(data[idxs[i],:,:,0])
        ax.axis('off')
    plt.show()

Create the siamese net feature extraction model


In [10]:
img_placeholder = tf.placeholder(tf.float32, [None, 28, 28, 1], name='img')
net = mnist_model(img_placeholder, reuse=False)

Restore from checkpoint and calc the features from all of train data


In [11]:
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state("model")
    saver.restore(sess, "model/model.ckpt")
    
    train_feat = sess.run(net, feed_dict={img_placeholder:train_images[:10000]})


INFO:tensorflow:Restoring parameters from model/model.ckpt

Searching for similar test images from trainset based on siamese feature


In [22]:
#generate new random test image
idx = np.random.randint(0, len_test)
im = test_images[idx]

In [21]:
#show the test image
show_image(idx, test_images)
print("This is image from id:", idx)


This is image from id: 9522

In [20]:
#run the test image through the network to get the test features
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state("model")
    saver.restore(sess, "model/model.ckpt")
    search_feat = sess.run(net, feed_dict={img_placeholder:[im]})
    
#calculate the cosine similarity and sort
dist = cdist(train_feat, search_feat, 'cosine')
rank = np.argsort(dist.ravel())

#show the top n similar image from train data
n = 7
show_image(rank[:n], train_images)
print("retrieved ids:", rank[:n])


INFO:tensorflow:Restoring parameters from model/model.ckpt
retrieved ids: [1480 4334  933 6889 7616 3579 7386]