Building an image retrieval system with deep features

Fire up GraphLab Create


In [1]:
import graphlab



Load the CIFAR-10 dataset

We will use a popular benchmark dataset in computer vision called CIFAR-10.

(We've reduced the data to just 4 categories = {'cat','bird','automobile','dog'}.)

This dataset is already split into a training set and test set. In this simple retrieval example, there is no notion of "testing", so we will only use the training data.


In [2]:
image_train = graphlab.SFrame('image_train_data/')


[INFO] graphlab.cython.cy_server: GraphLab Create v2.0 started. Logging: /tmp/graphlab_server_1468700145.log

Computing deep features for our images

The two (commented-out) lines below show how to compute deep features. This computation takes a while, so we have already computed them and saved the results as a column in the data you just loaded.

(Note that if you would like to compute such deep features yourself and have a GPU on your machine, you should use the GPU-enabled GraphLab Create, which will be significantly faster for this task.)


In [3]:
#deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [4]:
image_train.head()


Out[4]:
id  | image                | label      | deep_features                                           | image_array
24  | Height: 32 Width: 32 | bird       | [0.242871761322, 1.09545373917, 0.0, ...]               | [73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ...]
33  | Height: 32 Width: 32 | cat        | [0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...]          | [7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ...]
36  | Height: 32 Width: 32 | cat        | [0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...]          | [169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ...]
70  | Height: 32 Width: 32 | dog        | [1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...]     | [154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ...]
90  | Height: 32 Width: 32 | bird       | [1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...]           | [216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ...]
97  | Height: 32 Width: 32 | automobile | [1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...]           | [33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ...]
107 | Height: 32 Width: 32 | dog        | [0.0, 0.0, 0.220677852631, 0.0, ...]                    | [97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ...]
121 | Height: 32 Width: 32 | bird       | [0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...]           | [93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ...]
136 | Height: 32 Width: 32 | automobile | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...]  | [35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ...]
138 | Height: 32 Width: 32 | bird       | [0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...]          | [205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ...]
[10 rows x 5 columns]

Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.


In [5]:
knn_model = graphlab.nearest_neighbors.create(image_train,
                                              features=['deep_features'],
                                              label='id')


Starting brute force nearest neighbors model training.
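
Under the hood, a brute-force model like this one answers a query by computing the Euclidean distance from the query's deep-feature vector to every training vector and sorting. A minimal numpy sketch of that idea (the 3-dimensional toy vectors here are hypothetical stand-ins for the much longer deep features):

```python
import numpy as np

# Hypothetical toy feature vectors; the real deep features are much longer.
train = np.array([[0.2, 1.1, 0.0],
                  [0.5, 0.0, 0.3],
                  [1.7, 0.0, 0.9]])
query = np.array([0.5, 0.1, 0.3])

# Euclidean distance from the query to every training row.
dists = np.linalg.norm(train - query, axis=1)

# Indices of the k nearest training rows, closest first.
k = 2
nearest = np.argsort(dists)[:k]
print(nearest)  # [1 0] -- row 1 is closest, then row 0
```

This is why "training" a brute-force model is essentially instantaneous: all the distance computation happens at query time.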

Use the image retrieval model with deep features to find similar images

Let's find images similar to this cat picture.


In [6]:
graphlab.canvas.set_target('ipynb')
cat = image_train[18:19]
cat['image'].show()



In [7]:
knn_model.query(cat)


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.0498753   | 15.676ms     |
| Done         |         | 100         | 131.55ms     |
+--------------+---------+-------------+--------------+
Out[7]:
+-------------+-----------------+---------------+------+
| query_label | reference_label |    distance   | rank |
+-------------+-----------------+---------------+------+
|      0      |       384       |      0.0      |  1   |
|      0      |       6910      | 36.9403137951 |  2   |
|      0      |      39777      | 38.4634888975 |  3   |
|      0      |      36870      | 39.7559623119 |  4   |
|      0      |      41734      | 39.7866014148 |  5   |
+-------------+-----------------+---------------+------+
[5 rows x 4 columns]

To save typing, let's create a simple function that fetches the images corresponding to a query result:


In [8]:
def get_images_from_ids(query_result):
    return image_train.filter_by(query_result['reference_label'],'id')
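
filter_by keeps exactly those rows of image_train whose 'id' value appears among the query result's reference labels. In plain Python terms (with a hypothetical handful of rows):

```python
# Plain-Python sketch of what the filter_by call above does:
# keep the rows whose 'id' appears among the retrieved reference labels.
rows = [{'id': 384, 'label': 'cat'},
        {'id': 500, 'label': 'dog'},
        {'id': 6910, 'label': 'cat'}]
reference_labels = [384, 6910]
wanted = set(reference_labels)
neighbors = [row for row in rows if row['id'] in wanted]
print([row['id'] for row in neighbors])  # [384, 6910]
```

Note that the result comes back in image_train's row order, not in the ranked order of the query result.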

In [9]:
cat_neighbors = get_images_from_ids(knn_model.query(cat))


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.0498753   | 14.268ms     |
| Done         |         | 100         | 129.585ms    |
+--------------+---------+-------------+--------------+

In [10]:
cat_neighbors['image'].show()


Very cool results showing similar cats.

Finding similar images to a car


In [11]:
car = image_train[8:9]
car['image'].show()



In [12]:
get_images_from_ids(knn_model.query(car))['image'].show()


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.0498753   | 8.127ms      |
| Done         |         | 100         | 134.004ms    |
+--------------+---------+-------------+--------------+

Just for fun, let's create a lambda that finds and shows the nearest-neighbor images for a given training-set index:


In [13]:
show_neighbors = lambda i: get_images_from_ids(knn_model.query(image_train[i:i+1]))['image'].show()

In [14]:
show_neighbors(8)


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.0498753   | 8.451ms      |
| Done         |         | 100         | 126.433ms    |
+--------------+---------+-------------+--------------+

In [15]:
show_neighbors(26)


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.0498753   | 9.656ms      |
| Done         |         | 100         | 125.294ms    |
+--------------+---------+-------------+--------------+

Let's look at how the labels are distributed in the training data:


In [17]:
image_train['label'].sketch_summary()


Out[17]:
+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  2005 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+-------+------------+-----+-----+------+
| value | automobile | cat | dog | bird |
+-------+------------+-----+-----+------+
| count |    509     | 509 | 509 | 478  |
+-------+------------+-----+-----+------+

In [19]:
show_neighbors(0)


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.0498753   | 10.619ms     |
| Done         |         | 100         | 125.363ms    |
+--------------+---------+-------------+--------------+

Next, let's split the training data by category so we can train a separate nearest-neighbors model for each:


In [20]:
cat_images = image_train[image_train['label']=='cat']

In [23]:
dog_images = image_train[image_train['label']=='dog']

In [24]:
car_images = image_train[image_train['label']=='automobile']

In [27]:
bird_images = image_train[image_train['label']=='bird']

In [32]:
bird_images.head()


Out[32]:
id  | image                | label | deep_features                                            | image_array
24  | Height: 32 Width: 32 | bird  | [0.242871761322, 1.09545373917, 0.0, ...]                | [73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ...]
90  | Height: 32 Width: 32 | bird  | [1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...]            | [216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ...]
121 | Height: 32 Width: 32 | bird  | [0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...]            | [93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ...]
138 | Height: 32 Width: 32 | bird  | [0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...]           | [205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ...]
335 | Height: 32 Width: 32 | bird  | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 8.50706672668, 0.0, ...]  | [160.0, 159.0, 154.0, 162.0, 161.0, 156.0, ...]
560 | Height: 32 Width: 32 | bird  | [1.69159495831, 0.0, 0.0, 0.0, 0.0, 0.0, ...]            | [147.0, 138.0, 88.0, 151.0, 142.0, 92.0, ...]
649 | Height: 32 Width: 32 | bird  | [0.511156201363, 0.324165046215, 0.0, ...]               | [65.0, 127.0, 9.0, 127.0, 160.0, 15.0, 159.0, ...]
775 | Height: 32 Width: 32 | bird  | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0127315521, 0.0, ...]  | [29.0, 41.0, 25.0, 29.0, 42.0, 25.0, 28.0, 41.0, ...]
802 | Height: 32 Width: 32 | bird  | [0.277166724205, 0.0, 0.0, 0.0, 0.0, 0.0, ...]           | [233.0, 230.0, 173.0, 222.0, 218.0, 168.0, ...]
975 | Height: 32 Width: 32 | bird  | [0.0, 0.0336718559265, 0.0, 0.645326733589, ...]         | [59.0, 180.0, 110.0, 88.0, 186.0, 117.0, ...]
[10 rows x 5 columns]


In [34]:
cat_model = graphlab.nearest_neighbors.create(cat_images, features=['deep_features'], label='id')


Starting brute force nearest neighbors model training.

In [47]:
dog_model = graphlab.nearest_neighbors.create(dog_images, features=['deep_features'], label='id')


Starting brute force nearest neighbors model training.

In [48]:
car_model = graphlab.nearest_neighbors.create(car_images, features=['deep_features'], label='id')


Starting brute force nearest neighbors model training.

In [49]:
bird_model = graphlab.nearest_neighbors.create(bird_images, features=['deep_features'], label='id')


Starting brute force nearest neighbors model training.

Now load the test data:


In [38]:
image_test = graphlab.SFrame('image_test_data/')

In [55]:
cat_model.query(image_test[0:1])[0:5]['distance'].mean()


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.196464    | 7.217ms      |
| Done         |         | 100         | 44.893ms     |
+--------------+---------+-------------+--------------+
Out[55]:
36.15573070978294

In [56]:
dog_model.query(image_test[0:1])[0:5]['distance'].mean()


Starting pairwise querying.
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 0            | 1       | 0.196464    | 9.059ms      |
| Done         |         | 100         | 46.572ms     |
+--------------+---------+-------------+--------------+
Out[56]:
37.77071136184157

In [52]:
dog_images.filter_by(16976, 'id').show()



In [57]:
dog_test_images = image_test.filter_by('dog', 'label')

In [59]:
dog_cat = cat_model.query(dog_test_images, k=1)


Starting blockwise querying.
max rows per data block: 4348
number of reference data blocks: 8
number of query data blocks: 1
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 1000         | 64000   | 12.5737     | 293.062ms    |
| Done         | 509000  | 100         | 331.722ms    |
+--------------+---------+-------------+--------------+

In [61]:
dog_dog = dog_model.query(dog_test_images, k=1)


Starting blockwise querying.
max rows per data block: 4348
number of reference data blocks: 8
number of query data blocks: 1
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 1000         | 63000   | 12.3772     | 273.57ms     |
| Done         | 509000  | 100         | 296.398ms    |
+--------------+---------+-------------+--------------+

In [62]:
dog_car = car_model.query(dog_test_images, k=1)


Starting blockwise querying.
max rows per data block: 4348
number of reference data blocks: 8
number of query data blocks: 1
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 1000         | 63000   | 12.3772     | 262.279ms    |
| Done         | 509000  | 100         | 301.378ms    |
+--------------+---------+-------------+--------------+

In [63]:
dog_bird = bird_model.query(dog_test_images, k=1)


Starting blockwise querying.
max rows per data block: 4348
number of reference data blocks: 8
number of query data blocks: 1
+--------------+---------+-------------+--------------+
| Query points | # Pairs | % Complete. | Elapsed Time |
+--------------+---------+-------------+--------------+
| 1000         | 59000   | 12.3431     | 264.303ms    |
| Done         | 478000  | 100         | 302.108ms    |
+--------------+---------+-------------+--------------+

In [64]:
dog_distances = graphlab.SFrame({'dog': dog_dog['distance'],
                                 'cat': dog_cat['distance'],
                                 'car': dog_car['distance'],
                                 'bird': dog_bird['distance']})

In [65]:
dog_distances.head()


Out[65]:
+---------------+---------------+---------------+---------------+
|      bird     |      car      |      cat      |      dog      |
+---------------+---------------+---------------+---------------+
| 41.7538647304 | 41.9579761457 | 36.4196077068 | 33.4773590373 |
| 41.3382958925 | 46.0021331807 | 38.8353268874 | 32.8458495684 |
| 38.6157590853 | 42.9462290692 | 36.9763410854 | 35.0397073189 |
| 37.0892269954 | 41.6866060048 | 34.5750072914 | 33.9010327697 |
| 38.272288694  | 39.2269664935 | 34.778824791  | 37.4849250909 |
| 39.1462089236 | 40.5845117698 | 35.1171578292 | 34.945165344  |
| 40.523040106  | 45.1067352961 | 40.6095830913 | 39.0957278345 |
| 38.1947918393 | 41.3221140974 | 39.9036867306 | 37.7696131032 |
| 40.1567131661 | 41.8244654995 | 38.0674700168 | 35.1089144603 |
| 45.5597962603 | 45.4976929401 | 42.7258732951 | 43.2422832585 |
+---------------+---------------+---------------+---------------+
[10 rows x 4 columns]


Finally, let's count how many dog test images are closer to their nearest dog training image than to the nearest image of any other category:


In [75]:
dog_distances.apply(lambda row: row['dog'] == min(row.values())).sum()


Out[75]:
678
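
The apply above checks, row by row, whether the dog-model distance is the smallest of the four; summing the booleans counts the correctly matched test images. The same check in plain Python, using the first and fifth rows of dog_distances shown earlier:

```python
# Row-wise check: a dog test image counts as "correct" when its nearest
# dog training image is closer than its nearest bird, car, or cat.
rows = [
    {'bird': 41.7538647304, 'car': 41.9579761457,
     'cat': 36.4196077068, 'dog': 33.4773590373},  # dog is the minimum
    {'bird': 38.272288694, 'car': 39.2269664935,
     'cat': 34.778824791, 'dog': 37.4849250909},   # cat is the minimum
]
correct = sum(row['dog'] == min(row.values()) for row in rows)
print(correct)  # 1
```

Since the query logs above show 1,000 dog test images, the 678 correct matches correspond to an accuracy of about 67.8%.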

In [69]:
min(dog_distances[0].values())


Out[69]:
33.47735903726335

In [ ]: