In [4]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..'))) # To import keras_squeezenet.
from keras_squeezenet import SqueezeNet
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
from keras.preprocessing import image
In [5]:
import keras.backend as K
import matplotlib.pyplot as plt
%matplotlib inline
In [45]:
img = image.load_img('../images/cat.jpeg', target_size=(227, 227))
In [46]:
# Let's check out this nice lookin' cat!
plt.imshow(img)
Out[46]:
<matplotlib.image.AxesImage at 0x7f416bb4b2d0>
In [47]:
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
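For reference, preprocess_input from keras.applications.imagenet_utils (in what I believe is its default 'caffe' mode in this Keras version) converts RGB to BGR and subtracts the ImageNet per-channel means, without scaling. A rough hand-rolled equivalent, purely for illustration:
x_manual = image.img_to_array(img)
x_manual = np.expand_dims(x_manual, axis=0)
x_manual = x_manual[..., ::-1]                    # RGB -> BGR
x_manual -= np.array([103.939, 116.779, 123.68])  # subtract ImageNet means (BGR order)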
The weights should look random, and the model should not be able to predict the correct class.
In [56]:
# Get weights that are not trained (i.e., randomly initialized)
model = SqueezeNet(weights=None)
In [57]:
W = K.eval(model.weights[0])
print(W.shape)
plt.figure(figsize=(18,12))
for idx in range(W.shape[-1]):
    plt.subplot(8,8,idx+1)
    plt.imshow(W[:,:,:,idx]); plt.colorbar()
(3, 3, 3, 64)
In [60]:
preds = model.predict(x)
print('Predicted:', decode_predictions(preds))
# These predictions should look random (since the network is not trained)!
('Predicted:', [[(u'n02110185', u'Siberian_husky', 0.0010239105), (u'n01910747', u'jellyfish', 0.0010204196), (u'n02486410', u'baboon', 0.0010175246), (u'n02643566', u'lionfish', 0.0010172399), (u'n02093754', u'Border_terrier', 0.0010149083)]])
Now check that we can load the pretrained weights. The filters should show some structure, and the classification should be accurate, since the model has already been trained on ImageNet.
In [21]:
# Get weights that are trained on ImageNet.
model = SqueezeNet(weights='imagenet')
In [52]:
W = K.eval(model.weights[0])
print(W.shape)
plt.figure(figsize=(18,12))
for idx in range(W.shape[-1]):
    plt.subplot(8,8,idx+1)
    plt.imshow(W[:,:,:,idx]); plt.colorbar()
(3, 3, 3, 64)
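Note that plt.imshow clips float RGB data to the [0, 1] range, so the raw kernels above can render as mostly saturated patches. A small sketch (my addition, not in the original notebook) that rescales the conv1 filters to [0, 1] before displaying:
W = K.eval(model.weights[0])                      # conv1 kernels, shape (3, 3, 3, 64)
W_scaled = (W - W.min()) / (W.max() - W.min())    # rescale to [0, 1] for display
plt.figure(figsize=(18, 12))
for idx in range(W_scaled.shape[-1]):
    plt.subplot(8, 8, idx + 1)
    plt.imshow(W_scaled[:, :, :, idx])
    plt.axis('off')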
In [62]:
preds = model.predict(x)
print('Predicted:', decode_predictions(preds))
# Alright! These look good!
('Predicted:', [[(u'n02123045', u'tabby', 0.82134342), (u'n02124075', u'Egyptian_cat', 0.12180641), (u'n02123159', u'tiger_cat', 0.05682119), (u'n02127052', u'lynx', 2.2597995e-05), (u'n02129604', u'tiger', 5.1768461e-06)]])
Next, save the weights for the include_top=False case. We'll do this by popping off the last few layers so that only the convolutional base remains.
In [22]:
# These are the layers of the full model.
model.summary()
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_4 (InputLayer) (None, 227, 227, 3) 0
____________________________________________________________________________________________________
conv1 (Conv2D) (None, 113, 113, 64) 1792
____________________________________________________________________________________________________
relu_conv1 (Activation) (None, 113, 113, 64) 0
____________________________________________________________________________________________________
pool1 (MaxPooling2D) (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire2/squeeze1x1 (Conv2D) (None, 56, 56, 16) 1040
____________________________________________________________________________________________________
fire2/relu_squeeze1x1 (Activatio (None, 56, 56, 16) 0
____________________________________________________________________________________________________
fire2/expand1x1 (Conv2D) (None, 56, 56, 64) 1088
____________________________________________________________________________________________________
fire2/expand3x3 (Conv2D) (None, 56, 56, 64) 9280
____________________________________________________________________________________________________
fire2/relu_expand1x1 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire2/relu_expand3x3 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire2/concat (Concatenate) (None, 56, 56, 128) 0
____________________________________________________________________________________________________
fire3/squeeze1x1 (Conv2D) (None, 56, 56, 16) 2064
____________________________________________________________________________________________________
fire3/relu_squeeze1x1 (Activatio (None, 56, 56, 16) 0
____________________________________________________________________________________________________
fire3/expand1x1 (Conv2D) (None, 56, 56, 64) 1088
____________________________________________________________________________________________________
fire3/expand3x3 (Conv2D) (None, 56, 56, 64) 9280
____________________________________________________________________________________________________
fire3/relu_expand1x1 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire3/relu_expand3x3 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire3/concat (Concatenate) (None, 56, 56, 128) 0
____________________________________________________________________________________________________
pool3 (MaxPooling2D) (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire4/squeeze1x1 (Conv2D) (None, 27, 27, 32) 4128
____________________________________________________________________________________________________
fire4/relu_squeeze1x1 (Activatio (None, 27, 27, 32) 0
____________________________________________________________________________________________________
fire4/expand1x1 (Conv2D) (None, 27, 27, 128) 4224
____________________________________________________________________________________________________
fire4/expand3x3 (Conv2D) (None, 27, 27, 128) 36992
____________________________________________________________________________________________________
fire4/relu_expand1x1 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire4/relu_expand3x3 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire4/concat (Concatenate) (None, 27, 27, 256) 0
____________________________________________________________________________________________________
fire5/squeeze1x1 (Conv2D) (None, 27, 27, 32) 8224
____________________________________________________________________________________________________
fire5/relu_squeeze1x1 (Activatio (None, 27, 27, 32) 0
____________________________________________________________________________________________________
fire5/expand1x1 (Conv2D) (None, 27, 27, 128) 4224
____________________________________________________________________________________________________
fire5/expand3x3 (Conv2D) (None, 27, 27, 128) 36992
____________________________________________________________________________________________________
fire5/relu_expand1x1 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire5/relu_expand3x3 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire5/concat (Concatenate) (None, 27, 27, 256) 0
____________________________________________________________________________________________________
pool5 (MaxPooling2D) (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire6/squeeze1x1 (Conv2D) (None, 13, 13, 48) 12336
____________________________________________________________________________________________________
fire6/relu_squeeze1x1 (Activatio (None, 13, 13, 48) 0
____________________________________________________________________________________________________
fire6/expand1x1 (Conv2D) (None, 13, 13, 192) 9408
____________________________________________________________________________________________________
fire6/expand3x3 (Conv2D) (None, 13, 13, 192) 83136
____________________________________________________________________________________________________
fire6/relu_expand1x1 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire6/relu_expand3x3 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire6/concat (Concatenate) (None, 13, 13, 384) 0
____________________________________________________________________________________________________
fire7/squeeze1x1 (Conv2D) (None, 13, 13, 48) 18480
____________________________________________________________________________________________________
fire7/relu_squeeze1x1 (Activatio (None, 13, 13, 48) 0
____________________________________________________________________________________________________
fire7/expand1x1 (Conv2D) (None, 13, 13, 192) 9408
____________________________________________________________________________________________________
fire7/expand3x3 (Conv2D) (None, 13, 13, 192) 83136
____________________________________________________________________________________________________
fire7/relu_expand1x1 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire7/relu_expand3x3 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire7/concat (Concatenate) (None, 13, 13, 384) 0
____________________________________________________________________________________________________
fire8/squeeze1x1 (Conv2D) (None, 13, 13, 64) 24640
____________________________________________________________________________________________________
fire8/relu_squeeze1x1 (Activatio (None, 13, 13, 64) 0
____________________________________________________________________________________________________
fire8/expand1x1 (Conv2D) (None, 13, 13, 256) 16640
____________________________________________________________________________________________________
fire8/expand3x3 (Conv2D) (None, 13, 13, 256) 147712
____________________________________________________________________________________________________
fire8/relu_expand1x1 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire8/relu_expand3x3 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire8/concat (Concatenate) (None, 13, 13, 512) 0
____________________________________________________________________________________________________
fire9/squeeze1x1 (Conv2D) (None, 13, 13, 64) 32832
____________________________________________________________________________________________________
fire9/relu_squeeze1x1 (Activatio (None, 13, 13, 64) 0
____________________________________________________________________________________________________
fire9/expand1x1 (Conv2D) (None, 13, 13, 256) 16640
____________________________________________________________________________________________________
fire9/expand3x3 (Conv2D) (None, 13, 13, 256) 147712
____________________________________________________________________________________________________
fire9/relu_expand1x1 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire9/relu_expand3x3 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire9/concat (Concatenate) (None, 13, 13, 512) 0
____________________________________________________________________________________________________
drop9 (Dropout) (None, 13, 13, 512) 0
____________________________________________________________________________________________________
conv10 (Conv2D) (None, 13, 13, 1000) 513000
____________________________________________________________________________________________________
relu_conv10 (Activation) (None, 13, 13, 1000) 0
____________________________________________________________________________________________________
global_average_pooling2d_2 (Glob (None, 1000) 0
____________________________________________________________________________________________________
loss (Activation) (None, 1000) 0
====================================================================================================
Total params: 1,235,496.0
Trainable params: 1,235,496.0
Non-trainable params: 0.0
____________________________________________________________________________________________________
In [24]:
# We don't want the classification head, so we'll pop() off the last five layers: 1) loss, 2) global pooling, 3) relu_conv10, 4) conv10, 5) drop9.
num_layers_to_pop = 5
for idx in range(num_layers_to_pop):
    print(model.layers.pop())
<keras.layers.core.Activation object at 0x7f41b00dc590>
<keras.layers.pooling.GlobalAveragePooling2D object at 0x7f41b0111590>
<keras.layers.core.Activation object at 0x7f41b0183a10>
<keras.layers.convolutional.Conv2D object at 0x7f41b01abad0>
In [25]:
model.layers.pop()
Out[25]:
<keras.layers.core.Dropout at 0x7f41b0145dd0>
In [26]:
model.summary()
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_4 (InputLayer) (None, 227, 227, 3) 0
____________________________________________________________________________________________________
conv1 (Conv2D) (None, 113, 113, 64) 1792
____________________________________________________________________________________________________
relu_conv1 (Activation) (None, 113, 113, 64) 0
____________________________________________________________________________________________________
pool1 (MaxPooling2D) (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire2/squeeze1x1 (Conv2D) (None, 56, 56, 16) 1040
____________________________________________________________________________________________________
fire2/relu_squeeze1x1 (Activatio (None, 56, 56, 16) 0
____________________________________________________________________________________________________
fire2/expand1x1 (Conv2D) (None, 56, 56, 64) 1088
____________________________________________________________________________________________________
fire2/expand3x3 (Conv2D) (None, 56, 56, 64) 9280
____________________________________________________________________________________________________
fire2/relu_expand1x1 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire2/relu_expand3x3 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire2/concat (Concatenate) (None, 56, 56, 128) 0
____________________________________________________________________________________________________
fire3/squeeze1x1 (Conv2D) (None, 56, 56, 16) 2064
____________________________________________________________________________________________________
fire3/relu_squeeze1x1 (Activatio (None, 56, 56, 16) 0
____________________________________________________________________________________________________
fire3/expand1x1 (Conv2D) (None, 56, 56, 64) 1088
____________________________________________________________________________________________________
fire3/expand3x3 (Conv2D) (None, 56, 56, 64) 9280
____________________________________________________________________________________________________
fire3/relu_expand1x1 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire3/relu_expand3x3 (Activation (None, 56, 56, 64) 0
____________________________________________________________________________________________________
fire3/concat (Concatenate) (None, 56, 56, 128) 0
____________________________________________________________________________________________________
pool3 (MaxPooling2D) (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire4/squeeze1x1 (Conv2D) (None, 27, 27, 32) 4128
____________________________________________________________________________________________________
fire4/relu_squeeze1x1 (Activatio (None, 27, 27, 32) 0
____________________________________________________________________________________________________
fire4/expand1x1 (Conv2D) (None, 27, 27, 128) 4224
____________________________________________________________________________________________________
fire4/expand3x3 (Conv2D) (None, 27, 27, 128) 36992
____________________________________________________________________________________________________
fire4/relu_expand1x1 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire4/relu_expand3x3 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire4/concat (Concatenate) (None, 27, 27, 256) 0
____________________________________________________________________________________________________
fire5/squeeze1x1 (Conv2D) (None, 27, 27, 32) 8224
____________________________________________________________________________________________________
fire5/relu_squeeze1x1 (Activatio (None, 27, 27, 32) 0
____________________________________________________________________________________________________
fire5/expand1x1 (Conv2D) (None, 27, 27, 128) 4224
____________________________________________________________________________________________________
fire5/expand3x3 (Conv2D) (None, 27, 27, 128) 36992
____________________________________________________________________________________________________
fire5/relu_expand1x1 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire5/relu_expand3x3 (Activation (None, 27, 27, 128) 0
____________________________________________________________________________________________________
fire5/concat (Concatenate) (None, 27, 27, 256) 0
____________________________________________________________________________________________________
pool5 (MaxPooling2D) (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire6/squeeze1x1 (Conv2D) (None, 13, 13, 48) 12336
____________________________________________________________________________________________________
fire6/relu_squeeze1x1 (Activatio (None, 13, 13, 48) 0
____________________________________________________________________________________________________
fire6/expand1x1 (Conv2D) (None, 13, 13, 192) 9408
____________________________________________________________________________________________________
fire6/expand3x3 (Conv2D) (None, 13, 13, 192) 83136
____________________________________________________________________________________________________
fire6/relu_expand1x1 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire6/relu_expand3x3 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire6/concat (Concatenate) (None, 13, 13, 384) 0
____________________________________________________________________________________________________
fire7/squeeze1x1 (Conv2D) (None, 13, 13, 48) 18480
____________________________________________________________________________________________________
fire7/relu_squeeze1x1 (Activatio (None, 13, 13, 48) 0
____________________________________________________________________________________________________
fire7/expand1x1 (Conv2D) (None, 13, 13, 192) 9408
____________________________________________________________________________________________________
fire7/expand3x3 (Conv2D) (None, 13, 13, 192) 83136
____________________________________________________________________________________________________
fire7/relu_expand1x1 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire7/relu_expand3x3 (Activation (None, 13, 13, 192) 0
____________________________________________________________________________________________________
fire7/concat (Concatenate) (None, 13, 13, 384) 0
____________________________________________________________________________________________________
fire8/squeeze1x1 (Conv2D) (None, 13, 13, 64) 24640
____________________________________________________________________________________________________
fire8/relu_squeeze1x1 (Activatio (None, 13, 13, 64) 0
____________________________________________________________________________________________________
fire8/expand1x1 (Conv2D) (None, 13, 13, 256) 16640
____________________________________________________________________________________________________
fire8/expand3x3 (Conv2D) (None, 13, 13, 256) 147712
____________________________________________________________________________________________________
fire8/relu_expand1x1 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire8/relu_expand3x3 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire8/concat (Concatenate) (None, 13, 13, 512) 0
____________________________________________________________________________________________________
fire9/squeeze1x1 (Conv2D) (None, 13, 13, 64) 32832
____________________________________________________________________________________________________
fire9/relu_squeeze1x1 (Activatio (None, 13, 13, 64) 0
____________________________________________________________________________________________________
fire9/expand1x1 (Conv2D) (None, 13, 13, 256) 16640
____________________________________________________________________________________________________
fire9/expand3x3 (Conv2D) (None, 13, 13, 256) 147712
____________________________________________________________________________________________________
fire9/relu_expand1x1 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire9/relu_expand3x3 (Activation (None, 13, 13, 256) 0
____________________________________________________________________________________________________
fire9/concat (Concatenate) (None, 13, 13, 512) 0
====================================================================================================
Total params: 722,496.0
Trainable params: 722,496.0
Non-trainable params: 0.0
____________________________________________________________________________________________________
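As a sanity check on the truncation, the parameter count drops by exactly the size of the removed conv10 layer; the other popped layers (activation, pooling, dropout) contribute no parameters. A quick arithmetic check:
conv1_params  = (3 * 3 * 3 + 1) * 64        # 1,792, matches the summary
conv10_params = (1 * 1 * 512 + 1) * 1000    # 513,000
assert 1235496 - 722496 == conv10_params    # full model minus truncated model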
In [29]:
model.save_weights('../weights/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5')
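An aside: popping entries off model.layers trims what summary() reports and what save_weights() serializes, but (at least in this Keras version) it does not rewire the graph that predict() runs. An arguably cleaner way to produce an equivalent no-top weights file is to build a truncated Model explicitly, using the 'fire9/concat' layer name from the summary above. A sketch, not part of the original notebook (the output filename is just a placeholder):
from keras.models import Model

full_model = SqueezeNet(weights='imagenet')
no_top = Model(inputs=full_model.input,
               outputs=full_model.get_layer('fire9/concat').output)
# Saving this model's weights should give an equivalent no-top file.
no_top.save_weights('../weights/squeezenet_notop_alternative.h5')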
Check that the include_top=False flag now works with the saved no-top weights.
In [37]:
from keras.utils import get_file
In [43]:
model = SqueezeNet(include_top=False, weights='imagenet')
Downloading data from https://github.com/jeremykawahara/keras-squeezenet/raw/master/weights/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5
In [49]:
out = model.predict(x)
In [50]:
print(out.shape)
(1, 13, 13, 512)
In [57]:
plt.imshow(np.squeeze(out[0,:,:,102])); plt.colorbar()
Out[57]:
<matplotlib.colorbar.Colorbar at 0x7f416aaaab10>
In [61]:
model = SqueezeNet(include_top=False, weights='imagenet', pooling='avg')
In [62]:
out = model.predict(x)
print(out.shape)
(1, 512)
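With pooling='avg', the model emits a single 512-dimensional feature vector per image, which is the usual starting point for transfer learning. A minimal, hypothetical sketch of stacking a new classifier on top (num_classes and the training call are placeholders, not part of the original notebook):
from keras.models import Model
from keras.layers import Dense

num_classes = 10  # placeholder for your own dataset

base = SqueezeNet(include_top=False, weights='imagenet', pooling='avg')
for layer in base.layers:
    layer.trainable = False  # freeze the convolutional base

preds = Dense(num_classes, activation='softmax', name='new_classifier')(base.output)
clf = Model(inputs=base.input, outputs=preds)
clf.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# clf.fit(X_train, y_train, ...) on your own labelled images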
Content source: rcmalli/keras-squeezenet