2. 特征提取：从VGG16到InceptionResNetV2

导入包


In [1]:
import h5py
import os
import time
import numpy as np

from keras.layers import *
from keras.models import *
from keras.applications import *
from keras.optimizers import *
from keras.regularizers import *
from keras.preprocessing.image import *
from keras.applications.inception_v3 import preprocess_input


Using TensorFlow backend.

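使用ImageNet预训练权重的VGG16、VGG19、ResNet50、MobileNet、Xception、InceptionV3和InceptionResNetV2模型提取特征（去掉顶层分类器，经全局平均池化后保存到 model/feature_<模型名>_<日期>.h5）。本次运行只启用了Xception、InceptionV3和InceptionResNetV2，其余模型的调用已被注释掉。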


In [2]:
def get_features(MODEL, image_size, date_str, lambda_func=None, batch_size=1):
    """Extract global-average-pooled CNN features and save them to an HDF5 file.

    Builds ``MODEL`` with ImageNet weights and without its classification head,
    appends a GlobalAveragePooling2D layer, and runs it over the images found in
    ``input/data_train``, ``input/data_val`` and ``input/data_test`` (relative to
    the current working directory). The results are written to
    ``model/feature_<MODEL.__name__>_<date_str>.h5`` with the datasets
    ``train``, ``train_labels``, ``val``, ``val_labels`` and ``test``.

    Args:
        MODEL: Keras application constructor (e.g. ``Xception``); its
            ``__name__`` is used in the log and in the output file name.
        image_size: ``(width, height)`` of the network input, in pixels.
        date_str: Tag appended to the output file name.
        lambda_func: Optional preprocessing function applied to the input
            tensor inside the model through a ``Lambda`` layer
            (e.g. ``xception.preprocess_input``).
        batch_size: Number of images per prediction batch.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    import math

    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {0}'.format(batch_size))

    print('{0} start.'.format(MODEL.__name__))
    start_time = time.time()

    width = image_size[0]
    height = image_size[1]
    # Keras works in (rows, cols) == (height, width). image_size is (width, height),
    # so it has to be swapped before being used as flow_from_directory's target_size.
    target_size = (height, width)
    input_tensor = Input((height, width, 3))
    x = input_tensor
    if lambda_func:
        print(lambda_func.__name__)
        # Model-specific preprocessing runs inside the graph, so the generators
        # below can feed unmodified pixel values.
        x = Lambda(lambda_func)(x)
    base_model = MODEL(input_tensor=x, weights='imagenet', input_shape=(height, width, 3), include_top=False)
    model = Model(base_model.input, GlobalAveragePooling2D()(base_model.output))

    cwd = os.getcwd()
    data_train_path = os.path.join(cwd, 'input', 'data_train')
    data_val_path = os.path.join(cwd, 'input', 'data_val')
    data_test_path = os.path.join(cwd, 'input', 'data_test')

    gen = ImageDataGenerator()
    # shuffle=False keeps prediction rows in the order of generator.classes, which is
    # what makes the saved *_labels datasets line up with the feature rows.
    train_generator = gen.flow_from_directory(data_train_path, target_size, shuffle=False,
                                              batch_size=batch_size)
    val_generator = gen.flow_from_directory(data_val_path, target_size, shuffle=False,
                                            batch_size=batch_size)
    test_generator = gen.flow_from_directory(data_test_path, target_size, shuffle=False,
                                             batch_size=batch_size)

    def num_steps(generator):
        # predict_generator pulls exactly `steps` batches. ceil(n_images / batch_size)
        # visits every image once (the last batch may be partial). Using n_images
        # steps would over-read whenever batch_size > 1.
        return int(math.ceil(len(generator.filenames) / float(batch_size)))

    print(len(train_generator.filenames))
    print(len(val_generator.filenames))
    print(len(test_generator.filenames))
    print('train_generator')
    train = model.predict_generator(train_generator, verbose=1, steps=num_steps(train_generator))
    print('val_generator')
    val = model.predict_generator(val_generator, verbose=1, steps=num_steps(val_generator))
    print('test_generator')
    test = model.predict_generator(test_generator, verbose=1, steps=num_steps(test_generator))

    folder_path = os.path.join(cwd, 'model')
    if not os.path.exists(folder_path):
        os.mkdir(folder_path)
    file_name = os.path.join(folder_path, 'feature_{0}_{1}.h5'.format(MODEL.__name__, date_str))
    print(file_name)

    # Explicit 'w' mode creates the file and truncates any previous version. Without
    # a mode, h5py >= 3 opens files read-only and writing would fail.
    with h5py.File(file_name, 'w') as h:
        h.create_dataset("train", data=train)
        h.create_dataset("train_labels", data=train_generator.classes)
        h.create_dataset("val", data=val)
        h.create_dataset("val_labels", data=val_generator.classes)
        h.create_dataset("test", data=test)

    print(train.shape)
    print(train_generator.classes)
    print(val.shape)
    print(test.shape)

    end_time = time.time()
    print('Spend time: {0} s'.format(end_time-start_time))

In [3]:
# Date tag embedded in the output file names (feature_<model>_<date>.h5).
# It is pinned as (year, month, day) so that re-running the notebook rewrites
# the same files; for today's date use time.strftime("%Y%m%d", time.localtime()).
date_str = '%04d%02d%02d' % (2018, 2, 23)
print(date_str)


20180223

In [4]:
# get_features(VGG16, (224, 224), date_str)

In [5]:
# get_features(VGG19, (224, 224), date_str)

In [6]:
# get_features(ResNet50, (224, 224), date_str)

In [7]:
# get_features(MobileNet, (224, 224), date_str)

In [8]:
# Xception at 299x299; its own preprocess_input is applied inside the model.
get_features(Xception, image_size=(299, 299), date_str=date_str,
             lambda_func=xception.preprocess_input)


Xception start.
preprocess_input
Found 9710 images belonging to 120 classes.
Found 512 images belonging to 120 classes.
Found 10357 images belonging to 1 classes.
9710
10357
train_generator
9710/9710 [==============================] - 1639s 169ms/step
val_generator
512/512 [==============================] - 87s 169ms/step
test_generator
10357/10357 [==============================] - 1749s 169ms/step
D:\Udacity\MachineLearning(Advanced)\p6_graduation_project\model\feature_Xception_20180223.h5
(9710, 2048)
[  0   0   0 ..., 119 119 119]
(10357, 2048)
Spend time: 3485.1917214393616 s

In [9]:
# InceptionV3 at 299x299; its own preprocess_input is applied inside the model.
get_features(InceptionV3, image_size=(299, 299), date_str=date_str,
             lambda_func=inception_v3.preprocess_input)


InceptionV3 start.
preprocess_input
Found 9710 images belonging to 120 classes.
Found 512 images belonging to 120 classes.
Found 10357 images belonging to 1 classes.
9710
10357
train_generator
9710/9710 [==============================] - 1576s 162ms/step
val_generator
512/512 [==============================] - 83s 162ms/step
test_generator
10357/10357 [==============================] - 1695s 164ms/step
D:\Udacity\MachineLearning(Advanced)\p6_graduation_project\model\feature_InceptionV3_20180223.h5
(9710, 2048)
[  0   0   0 ..., 119 119 119]
(10357, 2048)
Spend time: 3372.2834899425507 s

In [10]:
# InceptionResNetV2 at 299x299; its own preprocess_input is applied inside the model.
get_features(InceptionResNetV2, image_size=(299, 299), date_str=date_str,
             lambda_func=inception_resnet_v2.preprocess_input)


InceptionResNetV2 start.
preprocess_input
Found 9710 images belonging to 120 classes.
Found 512 images belonging to 120 classes.
Found 10357 images belonging to 1 classes.
9710
10357
train_generator
9710/9710 [==============================] - 2808s 289ms/step
val_generator
512/512 [==============================] - 147s 288ms/step
test_generator
10357/10357 [==============================] - 2985s 288ms/step
D:\Udacity\MachineLearning(Advanced)\p6_graduation_project\model\feature_InceptionResNetV2_20180223.h5
(9710, 1536)
[  0   0   0 ..., 119 119 119]
(10357, 1536)
Spend time: 5973.684639215469 s

In [11]:
# Completion marker for the feature-extraction cells above.
print('Done !', end='\n')


Done !

In [ ]: