Building an image retrieval system with deep features

Fire up GraphLab Create


In [1]:
import graphlab

Load the CIFAR-10 dataset

We will use a popular benchmark dataset in computer vision called CIFAR-10.

(We've reduced the data to just 4 categories: {'cat', 'bird', 'automobile', 'dog'}.)

This dataset is already split into a training set and a test set. In this simple retrieval example there is no notion of "testing", so we build the retrieval models from the training data only. Test images are used only as queries.


In [2]:
image_train = graphlab.SFrame('image_train_data/')
image_test = graphlab.SFrame('image_test_data/')



Computing deep features for our images

The two commented-out lines below compute deep features with a deep learning model trained on ImageNet, loaded from the URL in the first line. This computation takes a while, so we have already computed the features and saved them in the deep_features column of the data you loaded.

(Note that if you would like to compute such deep features and have a GPU on your machine, you should use the GPU-enabled version of GraphLab Create, which will be significantly faster for this task.)


In [3]:
#deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [4]:
image_train.head()


Out[4]:
id   image                 label       deep_features                                           image_array
24   Height: 32 Width: 32  bird        [0.242871761322, 1.09545373917, 0.0, ...]               [73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ...]
33   Height: 32 Width: 32  cat         [0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...]          [7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ...]
36   Height: 32 Width: 32  cat         [0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...]          [169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ...]
70   Height: 32 Width: 32  dog         [1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...]     [154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ...]
90   Height: 32 Width: 32  bird        [1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...]           [216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ...]
97   Height: 32 Width: 32  automobile  [1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...]           [33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ...]
107  Height: 32 Width: 32  dog         [0.0, 0.0, 0.220677852631, 0.0, ...]                    [97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ...]
121  Height: 32 Width: 32  bird        [0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...]           [93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ...]
136  Height: 32 Width: 32  automobile  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...]  [35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ...]
138  Height: 32 Width: 32  bird        [0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...]          [205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ...]
[10 rows x 5 columns]

Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.
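A brute-force nearest-neighbors query compares the query's feature vector against every reference vector and keeps the k closest. The plain-Python sketch below shows that idea. It assumes Euclidean distance. We believe GraphLab's default distance chooses this for a single numeric-array feature like deep_features, but check the nearest_neighbors documentation for your version. The helpers `euclidean` and `brute_force_knn` and the tiny feature vectors are hypothetical and are not part of GraphLab Create.

```python
import heapq
import math


# Hypothetical helper, not part of GraphLab Create.
def euclidean(a, b):
    """Euclidean (L2) distance between two equal-length feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


# Hypothetical helper, not part of GraphLab Create.
def brute_force_knn(query, references, k=5):
    """Return up to k (rank, reference_id, distance) tuples, closest first.

    `references` is a list of (reference_id, feature_vector) pairs. Every
    pair is scored against `query`, which is what "brute force" means here.
    """
    scored = ((euclidean(query, vec), ref_id) for ref_id, vec in references)
    closest = heapq.nsmallest(k, scored)  # k smallest distances, ascending
    return [(rank, ref_id, dist)
            for rank, (dist, ref_id) in enumerate(closest, start=1)]


# Made-up 3-dimensional "deep features", for illustration only.
refs = [(384, [0.5, 0.0, 1.0]), (6910, [0.4, 0.1, 0.9]), (24, [3.0, 2.0, 0.0])]
for row in brute_force_knn([0.5, 0.0, 1.0], refs, k=2):
    print(row)
```

Here the query vector is identical to reference 384, so that reference comes back at rank 1 with distance 0.0. This mirrors how a training image that is used as a query finds itself first in the results.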


In [5]:
knn_model = graphlab.nearest_neighbors.create(
    image_train, features=['deep_features'], label='id')


PROGRESS: Starting brute force nearest neighbors model training.

Use image retrieval model with deep features to find similar images

Let's find images similar to this cat picture.


In [6]:
graphlab.canvas.set_target('ipynb')
cat = image_train[18:19]
cat['image'].show()



In [7]:
knn_model.query(cat)


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 36.825ms     |
PROGRESS: | Done         |         | 100         | 171.746ms    |
PROGRESS: +--------------+---------+-------------+--------------+
Out[7]:
query_label  reference_label  distance       rank
0            384              0.0            1
0            6910             36.9403137951  2
0            39777            38.4634888975  3
0            36870            39.7559623119  4
0            41734            39.7866014148  5
[5 rows x 4 columns]

To save typing, we create a simple function that looks up the images of the nearest neighbors returned by a query:


In [8]:
def get_images_from_ids(query_result):
    return image_train.filter_by(query_result['reference_label'],'id')
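A caveat about the function above: filtering by a list of ids returns the matching rows. As far as we know, however, `filter_by` does not promise to return them in the rank order of the query, so it is safest not to rely on that order. The plain-Python sketch below works on lists of dicts rather than SFrames. It uses a hypothetical `rows_in_rank_order` helper that indexes the rows by id and then emits them in the order of the ranked ids.

```python
# Hypothetical helper operating on plain dicts, not GraphLab SFrames.
def rows_in_rank_order(rows, ranked_ids, id_key='id'):
    """Return the rows whose id appears in ranked_ids, ordered by rank.

    rows: list of dicts (a stand-in for SFrame rows).
    ranked_ids: reference ids in rank order, e.g. the values of the
    'reference_label' column of a query result.
    """
    by_id = {row[id_key]: row for row in rows}
    # Skip ids that have no matching row instead of raising KeyError.
    return [by_id[i] for i in ranked_ids if i in by_id]


rows = [{'id': 24, 'label': 'bird'}, {'id': 33, 'label': 'cat'},
        {'id': 36, 'label': 'cat'}]
print([r['id'] for r in rows_in_rank_order(rows, [36, 24, 99])])  # [36, 24]
```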

In [9]:
cat_neighbors = get_images_from_ids(knn_model.query(cat))


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 17.384ms     |
PROGRESS: | Done         |         | 100         | 130.046ms    |
PROGRESS: +--------------+---------+-------------+--------------+

In [10]:
cat_neighbors['image'].show()


Very cool results showing similar cats.

Finding similar images to a car


In [11]:
car = image_train[8:9]
car['image'].show()



In [12]:
get_images_from_ids(knn_model.query(car))['image'].show()


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 26.325ms     |
PROGRESS: | Done         |         | 100         | 132.668ms    |
PROGRESS: +--------------+---------+-------------+--------------+

Just for fun, let's create a lambda that finds and shows the nearest-neighbor images for the training image at a given index.


In [13]:
show_neighbors = lambda i: get_images_from_ids(knn_model.query(image_train[i:i+1]))['image'].show()

In [14]:
show_neighbors(8)


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 24.623ms     |
PROGRESS: | Done         |         | 100         | 126.575ms    |
PROGRESS: +--------------+---------+-------------+--------------+

In [15]:
show_neighbors(26)


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 15.068ms     |
PROGRESS: | Done         |         | 100         | 151.14ms     |
PROGRESS: +--------------+---------+-------------+--------------+

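Next, we split the training data by label and build a separate nearest-neighbors model for each of the 4 categories.

In [16]: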
auto_data = image_train[image_train['label'] == 'automobile']
cat_data = image_train[image_train['label'] == 'cat']
dog_data = image_train[image_train['label'] == 'dog']
bird_data = image_train[image_train['label'] == 'bird']
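Each line above keeps only the rows whose label matches one category. As a plain-Python illustration of the same split, a single pass can group every row by its label. The `split_by_label` helper and the sample rows are hypothetical; they are not GraphLab code.

```python
from collections import defaultdict


# Hypothetical helper operating on plain dicts, not GraphLab SFrames.
def split_by_label(rows, label_key='label'):
    """Group rows (dicts) into a dict mapping each label to its rows."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[label_key]].append(row)  # preserves input order per label
    return dict(groups)


rows = [{'id': 24, 'label': 'bird'}, {'id': 33, 'label': 'cat'},
        {'id': 36, 'label': 'cat'}, {'id': 97, 'label': 'automobile'}]
groups = split_by_label(rows)
print({label: [r['id'] for r in group] for label, group in groups.items()})
```

With only 4 categories, the four separate filters in the notebook are perfectly fine. The grouping form mainly avoids repeating the same line once per label.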

In [18]:
auto_model = graphlab.nearest_neighbors.create(
    auto_data, features=['deep_features'], label='id')


PROGRESS: Starting brute force nearest neighbors model training.

In [19]:
cat_model = graphlab.nearest_neighbors.create(
    cat_data, features=['deep_features'], label='id')


PROGRESS: Starting brute force nearest neighbors model training.

In [20]:
dog_model = graphlab.nearest_neighbors.create(
    dog_data, features=['deep_features'], label='id')


PROGRESS: Starting brute force nearest neighbors model training.

In [21]:
bird_model = graphlab.nearest_neighbors.create(
    bird_data, features=['deep_features'], label='id')


PROGRESS: Starting brute force nearest neighbors model training.

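Now let's take the first image in the test set and compare its nearest neighbors in the cat model with its nearest neighbors in the dog model.

In [22]: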
cat = image_test[0:1]

In [23]:
# Query the cat-only model with the test image and show its nearest neighbors
cat_model_query = cat_model.query(cat)
cat_neighbors = get_images_from_ids(cat_model_query)
cat_neighbors['image'].show()


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 12.189ms     |
PROGRESS: | Done         |         | 100         | 49.705ms     |
PROGRESS: +--------------+---------+-------------+--------------+

In [24]:
cat_model_query['distance'].mean()


Out[24]:
36.15573070978294
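The mean of the distance column summarizes how close the query sits to its k nearest neighbors within one per-category model. A lower mean suggests that the query looks more like that category. The sketch below computes this summary for several per-category distance lists and picks the category with the smallest mean. Both the `closest_category` helper and the distance values are hypothetical and made up for illustration. They are not outputs of this notebook.

```python
# Hypothetical helper, not part of GraphLab Create.
def closest_category(distances_by_category):
    """Return (category, mean_distance) for the smallest mean k-NN distance."""
    means = {label: sum(d) / float(len(d))  # float() keeps Python 2 happy
             for label, d in distances_by_category.items()}
    best = min(means, key=means.get)
    return best, means[best]


# Made-up top-5 distances for one query against two per-category models.
example = {
    'cat': [34.0, 35.5, 36.0, 37.0, 38.5],
    'dog': [36.5, 37.0, 38.0, 39.5, 40.0],
}
print(closest_category(example))
```

In the notebook, the equivalent summary for the cat model is `cat_model_query['distance'].mean()`.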

In [25]:
dog_model_query = dog_model.query(cat)


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 13.98ms      |
PROGRESS: | Done         |         | 100         | 54.851ms     |
PROGRESS: +--------------+---------+-------------+--------------+

In [26]:
dog_neighbor = get_images_from_ids(dog_model_query)
dog_neighbor['image'].show()



In [ ]:
dog_model_query['distance'].mean()

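Finally, we query all four models with every dog in the test set (k=1). We then count how often a test dog's nearest dog neighbor is at least as close as its nearest cat, bird, and automobile neighbors.

In [ ]: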
image_test_auto = image_test[image_test['label'] == 'automobile']
image_test_cat = image_test[image_test['label'] == 'cat']
image_test_dog = image_test[image_test['label'] == 'dog']
image_test_bird = image_test[image_test['label'] == 'bird']

In [ ]:
dog_cat_neighbors = cat_model.query(image_test_dog, k=1)
dog_dog_neighbors = dog_model.query(image_test_dog, k=1)
dog_bird_neighbors = bird_model.query(image_test_dog, k=1)
dog_auto_neighbors = auto_model.query(image_test_dog, k=1)

In [ ]:
dog_distances = graphlab.SFrame()
dog_distances['dog-dog'] = dog_dog_neighbors['distance']
dog_distances['dog-cat'] = dog_cat_neighbors['distance']
dog_distances['dog-bird'] = dog_bird_neighbors['distance']
dog_distances['dog-auto'] = dog_auto_neighbors['distance']

In [ ]:
dog_distances

In [ ]:
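def is_dog_correct(row):
    # Return 1 if this test dog's nearest dog neighbor is at least as close
    # as its nearest cat, bird, and automobile neighbors; otherwise return 0.
    for col_name in dog_distances.column_names():
        if row['dog-dog'] > row[col_name]:
            return 0
    return 1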

In [ ]:
num_correct = dog_distances.apply(is_dog_correct)

In [ ]:
num_correct.sum()
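Taken together, the cells above amount to a nearest-neighbor classifier evaluated on dogs only. A test dog counts as correct when its closest training dog is no farther than its closest cat, bird, or automobile. The plain-Python sketch below applies the same rule and reports the fraction correct alongside the count. The `count_correct` helper and the distance rows are hypothetical and invented for illustration.

```python
# Hypothetical helper operating on plain dicts, not GraphLab SFrames.
def count_correct(distance_rows, true_col='dog-dog'):
    """Return (count, fraction) of rows whose true-class distance is smallest.

    Mirrors is_dog_correct: ties count as correct, because that function
    only rejects a row when another distance is strictly smaller.
    Assumes distance_rows is non-empty.
    """
    correct = sum(
        1 for row in distance_rows
        if all(row[true_col] <= d for d in row.values())
    )
    return correct, correct / float(len(distance_rows))


rows = [
    {'dog-dog': 30.0, 'dog-cat': 32.0, 'dog-bird': 40.0, 'dog-auto': 45.0},
    {'dog-dog': 35.0, 'dog-cat': 33.0, 'dog-bird': 41.0, 'dog-auto': 44.0},
    {'dog-dog': 31.0, 'dog-cat': 31.0, 'dog-bird': 39.0, 'dog-auto': 42.0},
]
print(count_correct(rows))
```

Dividing by the number of test dogs turns `num_correct.sum()` into an accuracy, which is easier to compare across categories of different sizes.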

In [ ]:
cat_model_query

In [ ]:
cat_model.query(cat)
