Model-Sizing for Keras CNN Model Zoo

This notebook is a sanity check of the model sizes and speeds reported in:

In particular, their model-comparison graph:

and this recent blog post (which came out well after this notebook was built):

import keras
#import tensorflow.contrib.keras as keras
import numpy as np

# Optional (disabled): fetch the original Keras "deep-learning-models" code and
# install it as ./models/keras_deep_learning_models, so that it can be imported
# (e.g. `from keras_deep_learning_models.imagenet_utils import ...`).
# Change `False` to `True` to enable.
if False:
    import os, sys

    targz = "v0.5.tar.gz"
    # NOTE(review): presumably the GitHub release archive of deep-learning-models
    # (it unpacks to 'deep-learning-models-0.5') -- confirm the URL.
    url = "https://github.com/fchollet/deep-learning-models/archive/"+targz
    models_orig_dir = 'deep-learning-models-0.5'
    models_here_dir = 'keras_deep_learning_models'
    models_dir = './models/'
    models_here_path = os.path.join(models_dir, models_here_dir)

    os.makedirs(models_dir, exist_ok=True)

    # Only download/unpack when the installed code directory is not there yet.
    if not os.path.isdir(models_here_path):
        tarfilepath = os.path.join(models_dir, targz)
        if not os.path.isfile(tarfilepath):
            import urllib.request
            urllib.request.urlretrieve(url, tarfilepath)
        import tarfile, shutil
        with tarfile.open(tarfilepath, 'r:gz') as tar:
            tar.extractall(models_dir)
        # Rename the unpacked directory to a valid Python module name.
        shutil.move(os.path.join(models_dir, models_orig_dir), models_here_path)

    # Make the installed code importable from this notebook.
    if models_dir not in sys.path:
        sys.path.append(models_dir)

    print("Keras Model Zoo model code installed")

#import keras
#if keras.__version__ < '2.0.0':
#    print("keras version = %s is too old" % (keras.__version__,))

from keras.applications.inception_v3 import decode_predictions
from keras.preprocessing import image as keras_preprocessing_image

#from keras_deep_learning_models.imagenet_utils import decode_predictions

#from tensorflow.contrib.keras.api.keras.applications.inception_v3 import decode_predictions
#from tensorflow.contrib.keras.api.keras.preprocessing import image as keras_preprocessing_image

# This call to 'decode_predictions' will potentially download
# imagenet_class_index.json (~35KB) the first time it is used; calling it once
# here on a dummy all-zeros (1, 1000) prediction gets any download done up-front.
decode_predictions(np.zeros( (1,1000) ), top=1)

Image Loading and Pre-processing

def image_to_input(model, preprocess_input_fn, img_path):
    """Load an image file as a preprocessed single-image batch for `model`.

    The image is resized to the spatial size given by ``model.input_shape``,
    so it works for models with different input sizes (224x224, 299x299, ...).

    Args:
        model: Keras model with a fixed-size image input.
        preprocess_input_fn: the model family's ``preprocess_input`` function.
        img_path: path of the image file to load.

    Returns:
        The preprocessed numpy array with a leading batch axis of size 1.
    """
    input_shape = model.input_shape
    # Keras puts channels last (None, H, W, C) or first (None, C, H, W).
    if keras.backend.image_data_format() == 'channels_first':
        target_size = tuple(input_shape[2:4])
    else:
        target_size = tuple(input_shape[1:3])
    img = keras_preprocessing_image.load_img(img_path, target_size=target_size)
    x = keras_preprocessing_image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch axis
    return preprocess_input_fn(x)

def test_model_sanity(model, preprocess_input_fn, img_path, img_class_str=''):
    """Classify `img_path` with `model` and print the top-1 ImageNet prediction.

    Args:
        model: Keras ImageNet classifier.
        preprocess_input_fn: the model family's ``preprocess_input`` function.
        img_path: path of the test image.
        img_class_str: expected class name (e.g. 'tabby'); when non-empty, a
            warning is printed if the top-1 prediction does not match it.

    Returns:
        The decoded top-1 predictions, e.g. [[('n02123045', 'tabby', 0.766)]].
    """
    x = image_to_input(model, preprocess_input_fn, img_path)
    preds = model.predict(x)
    predictions = decode_predictions(preds, top=1)
    if img_class_str and predictions[0][0][1] != img_class_str:
        print("INCORRECT CLASS!")
    # Report the prediction whether or not an expected class was supplied.
    print('Predicted:', predictions)
    # prints: [[('n02123045', 'tabby', 0.76617092)]]
    return predictions

# Test image, and the ImageNet class name the models should predict for it.
img_path = './images/cat-with-tongue_224x224.jpg'
img_class = 'tabby'

Model loading / timing

import time
def load_model_weights(fn, weight_set, assume_download=30):
    """Build a model with ``fn(weights=weight_set)`` and time the build.

    If the first build takes longer than `assume_download` seconds, the weights
    are assumed to have just been downloaded. The model is then built a second
    time, and only that second build is timed.

    Args:
        fn: model constructor, e.g. ``ResNet50``.
        weight_set: value passed as ``weights=`` (e.g. 'imagenet').
        assume_download: threshold in seconds above which a re-load is done.

    Returns:
        (model, load_seconds, trainable_params_millions, fixed_params_millions)
    """
    threshold = float(assume_download)
    t0 = time.perf_counter()
    m = fn(weights=weight_set)
    if time.perf_counter() - t0 > threshold:  # more than this => downloading, so retry to get set-up time cleanly
        print("Assume that >%gsecs means that we just downloaded the dataset : load again for timing" % (threshold,))
        t0 = time.perf_counter()
        m = fn(weights=weight_set)
    time_load = float(time.perf_counter() - t0)
    # Parameter counts in millions; set() avoids double-counting shared weights.
    weight_count = [float(np.sum([keras.backend.count_params(p) for p in set(w)]))/1000./1000.
                    for w in [m.trainable_weights, m.non_trainable_weights]]
    print("Loaded %.0fMM parameters (and %.0fk fixed parameters) into model in %.3f seconds" %
          (weight_count[0], weight_count[1]*1000., time_load,))
    return m, time_load, weight_count[0], weight_count[1]

def time_model_predictions(model, preprocess_input_fn, img_path, batch_size=1, iters=1):
    """Measure the average forward-pass time per image, in milliseconds.

    The test image is repeated `batch_size` times to form a batch, which is
    passed through ``model.predict`` `iters` times.

    Args:
        model: Keras model to time.
        preprocess_input_fn: the model family's ``preprocess_input`` function.
        img_path: path of the image used to build the batch.
        batch_size: number of images per ``predict`` call (>= 1).
        iters: number of timed ``predict`` calls (>= 1).

    Returns:
        Average wall-clock milliseconds per single image.

    Raises:
        ValueError: if `batch_size` or `iters` is less than 1.
    """
    if batch_size < 1 or iters < 1:
        raise ValueError("batch_size and iters must both be >= 1 (got %r, %r)" % (batch_size, iters))
    x = image_to_input(model, preprocess_input_fn, img_path)

    batch = np.tile(x, (batch_size, 1, 1, 1))  # repeat the one image along the batch axis
    t0 = time.perf_counter()
    for _ in range(iters):
        model.predict(batch, batch_size=batch_size)
    single = float(time.perf_counter() - t0)*1000./iters/batch_size
    print("A single image forward pass takes %.0f ms (in batches of %d, average of %d passes)" %
          (single, batch_size, iters,))
    return single

def total_summary(fn, preprocess_input_fn, batch_size=8, iters=2, release_memory=True):
    """Load, sanity-check and time one Keras application model.

    Uses the module-level test image `img_path` and its expected `img_class`.

    Args:
        fn: model constructor, e.g. ``VGG16``; its ``__name__`` names the result.
        preprocess_input_fn: the model family's ``preprocess_input`` function.
        batch_size: batch size used for the timing run.
        iters: number of timed ``predict`` calls.
        release_memory: clear the Keras session afterwards, so that evaluating
            several models in a row does not keep every graph in memory.

    Returns:
        dict with keys name, params_trainable, params_fixed (millions),
        time_setup (seconds) and time_iter_ms (ms per image).
    """
    model, time_setup, trainable, fixed = load_model_weights(fn, 'imagenet')
    test_model_sanity(model, preprocess_input_fn, img_path, img_class)
    time_iter_ms = time_model_predictions(model, preprocess_input_fn, img_path,
                                          batch_size=batch_size, iters=iters)
    model = None  # release our reference before clearing the session
    if release_memory:
        keras.backend.clear_session()
    return dict(name=fn.__name__,
                params_trainable=trainable, params_fixed=fixed,
                time_setup=time_setup, time_iter_ms=time_iter_ms)

from keras.applications.resnet50 import ResNet50, preprocess_input
#from tensorflow.contrib.keras.api.keras.applications.resnet50 import ResNet50, preprocess_input

# Stand-alone run on ResNet50: load the ImageNet weights (timed), check the
# prediction on the test image, then time batched forward passes.
model_resnet50 = load_model_weights(ResNet50, 'imagenet')[0]

test_model_sanity(model_resnet50, preprocess_input, img_path, img_class)

_ = time_model_predictions(model_resnet50, preprocess_input, img_path,
                           batch_size=8, iters=2)

# Drop the Python reference first, then let Keras release the session's memory.
model_resnet50 = None
keras.backend.clear_session()

Collect statistics

# Models to benchmark: remove the leading '#' from a name to enable it.
evaluate = ['#VGG16', '#InceptionV3', '#ResNet50', 'Xception', '#MobileNet', '#NotThisOne'] # remove the '#' to enable
#evaluate = ['VGG16', 'InceptionV3', 'ResNet50', 'Xception', 'MobileNet', '#NotThisOne']

# One total_summary() result dict per evaluated model. If it stays empty,
# the pre-recorded default statistics are used instead.
stats_arr = []

if 'VGG16' in evaluate:
    from keras.applications.vgg16 import VGG16, preprocess_input
    #from tensorflow.contrib.keras.api.keras.applications.vgg16 import VGG16, preprocess_input
    stats_arr.append( total_summary( VGG16, preprocess_input ) )

if 'InceptionV3' in evaluate:
    from keras.applications.inception_v3 import InceptionV3, preprocess_input
    #from tensorflow.contrib.keras.api.keras.applications.inception_v3 import InceptionV3, preprocess_input
    stats_arr.append( total_summary( InceptionV3, preprocess_input ) )

if 'ResNet50' in evaluate:
    from keras.applications.resnet50 import ResNet50, preprocess_input
    #from tensorflow.contrib.keras.api.keras.applications.resnet50 import ResNet50, preprocess_input
    stats_arr.append( total_summary( ResNet50, preprocess_input ) )

if 'Xception' in evaluate:
    from keras.applications.xception import Xception, preprocess_input
    #from tensorflow.contrib.keras.api.keras.applications.xception import Xception, preprocess_input
    stats_arr.append( total_summary( Xception, preprocess_input ) )

if 'MobileNet' in evaluate:
    from keras.applications.mobilenet import MobileNet, preprocess_input
    #from tensorflow.contrib.keras.api.keras.applications.mobilenet import MobileNet, preprocess_input
    stats_arr.append( total_summary( MobileNet, preprocess_input ) )

# Per-model statistics keyed by model name, with values truncated to 2 decimals.
stats = {s['name']: {k: int(v*100)/100. for k, v in s.items() if k != 'name'}
         for s in sorted(stats_arr, key=lambda x: x['name'])}

stats_cols = 'params_fixed params_trainable time_iter_ms time_setup'.split()

# Print the results as Python source in the same layout as the
# `'Name'   :statdict([...]),` default-stats tables, ready to paste there.
print(' '*33, stats_cols)
for k, stat in stats.items():
    print("        '%s:statdict([%s])," % (
        (k+"'"+' '*15)[:15],
        ', '.join(["%6.2f" % stat[c] for c in stats_cols]),))

Load default stats as a fallback (used when no models were evaluated above)

def statdict(arr):
    """Pair each value in `arr` with the column name at the same position in `stats_cols`."""
    return {col: val for col, val in zip(stats_cols, arr)}
d=statdict([0.03, 23.81, 642.46, 514])  # quick check of statdict()

if not stats_arr:
    # Nothing was benchmarked above: use previously recorded results.
    # Units: params in millions, time_iter_ms in ms per image, time_setup in seconds.
    stats_laptop_cpu={ # Updated 24-Aug-2017
    #                             'params_fixed params_trainable time_iter_ms time_setup'
        'InceptionV3'   :statdict([  0.03,  23.81,  631.63,   5.03]),
        'MobileNet'     :statdict([  0.02,   4.23,  197.90,   1.75]),
        'ResNet50'      :statdict([  0.05,  25.58,  567.64,   3.55]),
        'VGG16'         :statdict([  0.00, 138.35, 1026.74,   2.64]),
        'Xception'      :statdict([  0.05,  22.85, 1188.02,   3.10]),
    }
    stats_titanx={  # Updated 27-Aug-2017
    #                             'params_fixed params_trainable time_iter_ms time_setup'
        'InceptionV3'   :statdict([  0.03,  23.81, 215.10,   3.31]),
        'MobileNet'     :statdict([  0.02,   4.23,  66.68,   1.22]),
        'ResNet50'      :statdict([  0.05,  25.58, 196.02,   2.58]),
        'VGG16'         :statdict([  0.00, 138.35, 336.00,   1.01]),
        'Xception'      :statdict([  0.05,  22.85, 387.45,   1.85]),
    }
    stats = stats_titanx  # choose which recorded set to plot

Plot Graph (v. different axes)

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_xlabel('image processing time (ms)')
ax.set_ylabel('# of parameters (millions)')

# Gather one point per model: x = ms per image, y = trainable params (millions).
names, X, Y, R = [], [], [], []
for name, data in stats.items():
    names.append(name)
    X.append(data['time_iter_ms'])
    Y.append(data['params_trainable'])
    # NOTE(review): marker area scaled by set-up time in seconds -- sizing choice, adjust as desired.
    R.append(100. * data['time_setup'])
ax.scatter(X, Y, s=R)

# Label each point with its model name in a yellow call-out box.
for name, x, y in zip(names, X, Y):
    ax.annotate(
        name, xy=(x, y), xytext=(+0, 30),
        textcoords='offset points', ha='right', va='bottom',
        bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
        arrowprops=dict(arrowstyle = '->', connectionstyle='arc3,rad=0'))

plt.show()

Summary: on-disk sizes of the downloaded weight files

! ls -sh ~/.keras/models/
 36K imagenet_class_index.json
 26M inception_v1_weights_tf_dim_ordering_tf_kernels.h5
 92M inception_v3_weights_tf_dim_ordering_tf_kernels.h5
 17M mobilenet_1_0_224_tf.h5
 99M resnet50_weights_tf_dim_ordering_tf_kernels.h5
528M vgg16_weights_tf_dim_ordering_tf_kernels.h5
 88M xception_weights_tf_dim_ordering_tf_kernels.h5

