以下為將 Keras 模型透過 freeze_graph 匯出（Export）成凍結模型（frozen model）的範例程式


In [8]:
import os

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
from keras import backend as K

import tensorflow as tf
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib

import numpy as np

Export to frozen model 的標準作法

  • 由於 TensorFlow Lite 的轉換工具需要以凍結（frozen）的 GraphDef（.pb）作為輸入，必須先把模型的 Graph 與 Weight 合併成單一檔案
  • 但在凍結之前，需要先透過 Session 取得已載入的 Graph & Weight
  • tf.train.write_graph：如果是使用 Keras 的話，就用 K.get_session().graph_def 取得 graph，然後輸出成 pbtxt
  • tf.train.Saver()：透過 K.get_session() 取得 Session，再以 tf.train.Saver().save() 將權重存成 checkpoint

In [9]:
def export_model_for_mobile(model_name, input_node_name, output_node_name,
                            output_dir='out',
                            placeholder_type_enum=tf.float32.as_datatype_enum):
    """Freeze the current Keras session graph and optimize it for inference.

    All artifacts are written under ``output_dir``:
      * ``<model_name>_graph.pbtxt``        -- the session GraphDef in text form,
      * ``<model_name>.chkp*``              -- a checkpoint of the session variables,
      * ``frozen_<model_name>.pb``          -- the graph with variables folded into constants,
      * ``tensorflow_lite_<model_name>.pb`` -- the frozen graph after
        ``optimize_for_inference``. Despite the name this is still a binary
        GraphDef, not a ``.tflite`` flatbuffer; it is the input for the TFLite converter.

    Args:
        model_name: Prefix used for every generated file name.
        input_node_name: Op name of the graph input (no ``:0`` suffix), or several
            op names separated by commas.
        output_node_name: Op name of the graph output, or several comma-separated names.
        output_dir: Directory that receives all artifacts (defaults to ``'out'``).
        placeholder_type_enum: DataType enum of the input placeholder(s) passed to
            ``optimize_for_inference`` -- a single enum or one per input name.
            Defaults to float32.

    Returns:
        Path of the optimized GraphDef file.
    """
    sess = K.get_session()

    # freeze_graph takes a comma-separated string, optimize_for_inference takes lists;
    # normalise once so both steps see exactly the same node names.
    input_names = [n.strip() for n in input_node_name.split(',') if n.strip()]
    output_names = [n.strip() for n in output_node_name.split(',') if n.strip()]

    graph_file = model_name + '_graph.pbtxt'
    graph_path = os.path.join(output_dir, graph_file)
    checkpoint_path = os.path.join(output_dir, model_name + '.chkp')
    frozen_path = os.path.join(output_dir, 'frozen_' + model_name + '.pb')
    optimized_path = os.path.join(output_dir, 'tensorflow_lite_' + model_name + '.pb')

    # 1) Dump the graph structure as text into output_dir.
    tf.train.write_graph(sess.graph_def, output_dir, graph_file)

    # 2) Save the current variable values. The Saver is created after the graph
    #    dump, so its save/restore ops are not part of the pbtxt written above.
    tf.train.Saver().save(sess, checkpoint_path)

    # 3) Fold the checkpointed variables into constants.
    #    NOTE(review): restore_op_name / filename_tensor_name are documented as
    #    unused in later TF 1.x releases; kept with their original values.
    freeze_graph.freeze_graph(
        input_graph=graph_path,
        input_saver=None,
        input_binary=False,
        input_checkpoint=checkpoint_path,
        output_node_names=','.join(output_names),
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=frozen_path,
        clear_devices=True,
        initializer_nodes="")

    # 4) Reload the frozen graph and strip training-only ops for inference.
    input_graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, input_names, output_names, placeholder_type_enum)

    with tf.gfile.GFile(optimized_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    return optimized_path

創建 Graph


In [10]:
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K

# The architecture must match the saved weights exactly (see load_weights below).
# The summary's Flatten size (30*30*8 = 7200) and dense_1's 921,728 parameters follow
# from a 128x128x3 input, even though the weight file name mentions "32X32".
INPUT_SHAPE = [128, 128, 3]   # height, width, channels (channels_last, per the summary)
NUM_CLASSES = 7               # size of the softmax output

# Layer types and order are kept unchanged: Keras auto-names layers (conv2d_1, ...,
# activation_2), and those names become the graph node names used for export.
model = Sequential([
    Conv2D(8, (3, 3), activation='relu', input_shape=INPUT_SHAPE),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(8, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128),
    Activation('relu'),
    Dense(NUM_CLASSES),
    Activation('softmax'),
])
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 126, 126, 8)       224       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 63, 63, 8)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 61, 61, 8)         584       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 30, 30, 8)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 7200)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               921728    
_________________________________________________________________
activation_1 (Activation)    (None, 128)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 7)                 903       
_________________________________________________________________
activation_2 (Activation)    (None, 7)                 0         
=================================================================
Total params: 923,439
Trainable params: 923,439
Non-trainable params: 0
_________________________________________________________________

In [11]:
# Attach loss / optimizer / metrics to the model; one keyword per line for readability.
model.compile(
    loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'],
)

In [ ]:


In [12]:
# Trained weights for the architecture above. Set the CLASSMATE_WEIGHTS environment
# variable to point elsewhere instead of editing this machine-specific absolute path.
WEIGHTS_PATH = os.environ.get(
    "CLASSMATE_WEIGHTS",
    "/home/kent/git/DeepLearning_ClassmatesImageClassification_jdwang2018_5_25/"
    "CNN_Classfier_32X32_jdwang_Weight_1.h5",
)
model.load_weights(WEIGHTS_PATH)

In [ ]:

呼叫上面定義的 export_model_for_mobile 函式


In [13]:
# Tensor names look like "<op_name>:<output_index>"; the export helper wants bare op
# names, which is exactly what Tensor.op.name holds.
input_op_name = model.input.op.name
output_op_name = model.output.op.name
export_model_for_mobile('classmate_new', input_op_name, output_op_name)


INFO:tensorflow:Restoring parameters from out/classmate_new.chkp
INFO:tensorflow:Froze 8 variables.
Converted 8 variables to const ops.

In [ ]: