In [4]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..'))) # To import keras_squeezenet.
from keras_squeezenet import SqueezeNet
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
from keras.preprocessing import image

In [5]:
import keras.backend as K
import matplotlib.pyplot as plt
%matplotlib inline

In [45]:
img = image.load_img('../images/cat.jpeg', target_size=(227, 227))

In [46]:
# Let's check out this nice lookin' cat!
plt.imshow(img)


Out[46]:
<matplotlib.image.AxesImage at 0x7f416bb4b2d0>

In [47]:
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

Randomly initialized weights

The weights should look random, and it should not be able to predict the correct class.


In [56]:
# Get weights that are not trained (i.e., randomly initialized)
model = SqueezeNet(weights=None)

In [57]:
W = K.eval(model.weights[0])
print(W.shape)
plt.figure(figsize=(18,12))
for idx in range(W.shape[-1]):
    plt.subplot(8,8,idx+1)
    plt.imshow(W[:,:,:,idx]); plt.colorbar()


(3, 3, 3, 64)

In [60]:
preds = model.predict(x)
print('Predicted:', decode_predictions(preds))
# These predictions should look random (since the network is not trained)!


('Predicted:', [[(u'n02110185', u'Siberian_husky', 0.0010239105), (u'n01910747', u'jellyfish', 0.0010204196), (u'n02486410', u'baboon', 0.0010175246), (u'n02643566', u'lionfish', 0.0010172399), (u'n02093754', u'Border_terrier', 0.0010149083)]])

Weights pretrained over ImageNet

Check to make sure we can get the pretrained weights. The weights should have some structure, and the classification should be pretty good, since the model is already trained.


In [21]:
# Get weights that are trained over ImageNet.
model = SqueezeNet(weights='imagenet')

In [52]:
W = K.eval(model.weights[0])
print(W.shape)
plt.figure(figsize=(18,12))
for idx in range(W.shape[-1]):
    plt.subplot(8,8,idx+1)
    plt.imshow(W[:,:,:,idx]); plt.colorbar()


(3, 3, 3, 64)

In [62]:
preds = model.predict(x)
print('Predicted:', decode_predictions(preds))
# Alright! These look good!


('Predicted:', [[(u'n02123045', u'tabby', 0.82134342), (u'n02124075', u'Egyptian_cat', 0.12180641), (u'n02123159', u'tiger_cat', 0.05682119), (u'n02127052', u'lynx', 2.2597995e-05), (u'n02129604', u'tiger', 5.1768461e-06)]])

Weights for no_top

Here we will save the weights for when include_top is False. We'll do this by popping off the last few layers to get the layers we want.


In [22]:
# This is the full model layers.
model.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_4 (InputLayer)             (None, 227, 227, 3)   0                                            
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 113, 113, 64)  1792                                         
____________________________________________________________________________________________________
relu_conv1 (Activation)          (None, 113, 113, 64)  0                                            
____________________________________________________________________________________________________
pool1 (MaxPooling2D)             (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    1040                                         
____________________________________________________________________________________________________
fire2/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire2/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire2/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire2/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
fire3/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    2064                                         
____________________________________________________________________________________________________
fire3/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire3/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire3/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire3/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
pool3 (MaxPooling2D)             (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    4128                                         
____________________________________________________________________________________________________
fire4/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire4/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire4/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire4/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
fire5/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    8224                                         
____________________________________________________________________________________________________
fire5/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire5/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire5/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire5/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
pool5 (MaxPooling2D)             (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire6/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    12336                                        
____________________________________________________________________________________________________
fire6/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire6/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire6/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire6/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire7/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    18480                                        
____________________________________________________________________________________________________
fire7/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire7/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire7/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire7/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire8/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    24640                                        
____________________________________________________________________________________________________
fire8/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire8/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire8/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire8/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/concat (Concatenate)       (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
fire9/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    32832                                        
____________________________________________________________________________________________________
fire9/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire9/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire9/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire9/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/concat (Concatenate)       (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
drop9 (Dropout)                  (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
conv10 (Conv2D)                  (None, 13, 13, 1000)  513000                                       
____________________________________________________________________________________________________
relu_conv10 (Activation)         (None, 13, 13, 1000)  0                                            
____________________________________________________________________________________________________
global_average_pooling2d_2 (Glob (None, 1000)          0                                            
____________________________________________________________________________________________________
loss (Activation)                (None, 1000)          0                                            
====================================================================================================
Total params: 1,235,496.0
Trainable params: 1,235,496.0
Non-trainable params: 0.0
____________________________________________________________________________________________________

In [24]:
# We don't want to use the last layer, so we'll pop() off the 1) loss, 2) pooling, 3) relu, 4) conv10, 5) dropout layers.
num_layers_to_pop = 5
for idx in range(num_layers_to_pop):
    print(model.layers.pop())


<keras.layers.core.Activation object at 0x7f41b00dc590>
<keras.layers.pooling.GlobalAveragePooling2D object at 0x7f41b0111590>
<keras.layers.core.Activation object at 0x7f41b0183a10>
<keras.layers.convolutional.Conv2D object at 0x7f41b01abad0>

In [25]:
model.layers.pop()


Out[25]:
<keras.layers.core.Dropout at 0x7f41b0145dd0>

In [26]:
model.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_4 (InputLayer)             (None, 227, 227, 3)   0                                            
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 113, 113, 64)  1792                                         
____________________________________________________________________________________________________
relu_conv1 (Activation)          (None, 113, 113, 64)  0                                            
____________________________________________________________________________________________________
pool1 (MaxPooling2D)             (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    1040                                         
____________________________________________________________________________________________________
fire2/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire2/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire2/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire2/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
fire3/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    2064                                         
____________________________________________________________________________________________________
fire3/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire3/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire3/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire3/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
pool3 (MaxPooling2D)             (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    4128                                         
____________________________________________________________________________________________________
fire4/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire4/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire4/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire4/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
fire5/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    8224                                         
____________________________________________________________________________________________________
fire5/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire5/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire5/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire5/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
pool5 (MaxPooling2D)             (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire6/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    12336                                        
____________________________________________________________________________________________________
fire6/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire6/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire6/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire6/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire7/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    18480                                        
____________________________________________________________________________________________________
fire7/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire7/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire7/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire7/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire8/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    24640                                        
____________________________________________________________________________________________________
fire8/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire8/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire8/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire8/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/concat (Concatenate)       (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
fire9/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    32832                                        
____________________________________________________________________________________________________
fire9/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire9/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire9/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire9/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/concat (Concatenate)       (None, 13, 13, 512)   0                                            
====================================================================================================
Total params: 722,496.0
Trainable params: 722,496.0
Non-trainable params: 0.0
____________________________________________________________________________________________________

In [29]:
model.save_weights('../weights/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5')

include_top Checks

Check if include_top flag now works


In [37]:
from keras.utils import get_file

In [43]:
model = SqueezeNet(include_top=False, weights='imagenet')


Downloading data from https://github.com/jeremykawahara/keras-squeezenet/raw/master/weights/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5

In [49]:
out = model.predict(x)

In [50]:
print(out.shape)


(1, 13, 13, 512)

In [57]:
plt.imshow(np.squeeze(out[0,:,:,102])); plt.colorbar()


Out[57]:
<matplotlib.colorbar.Colorbar at 0x7f416aaaab10>

In [61]:
model = SqueezeNet(include_top=False, weights='imagenet', pooling='avg')

In [62]:
out = model.predict(x)
print(out.shape)


(1, 512)