In [8]:
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
from keras import backend as K
import tensorflow as tf
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib
import numpy as np
In [9]:
def export_model_for_mobile(model_name, input_node_name, output_node_name,
                            output_dir='out'):
    """Freeze the current Keras session graph and write an inference-optimized copy.

    Four files are written under ``output_dir``:

    * ``<model_name>_graph.pbtxt``        -- graph definition of the Keras session
    * ``<model_name>.chkp``               -- checkpoint saved from the same session
    * ``frozen_<model_name>.pb``          -- result of ``freeze_graph``
    * ``tensorflow_lite_<model_name>.pb`` -- result of ``optimize_for_inference``.
      Despite the file name, this is a serialized ``GraphDef`` protobuf.

    Args:
        model_name: Base name used for every file written.
        input_node_name: Name of the graph node that receives the input.
        output_node_name: Name of the graph node that produces the output.
        output_dir: Directory that receives all files. Defaults to ``'out'``,
            the location that was previously hard-coded.
    """
    import os

    graph_file = model_name + '_graph.pbtxt'
    graph_path = os.path.join(output_dir, graph_file)
    checkpoint_path = os.path.join(output_dir, model_name + '.chkp')
    frozen_path = os.path.join(output_dir, 'frozen_' + model_name + '.pb')
    optimized_path = os.path.join(output_dir,
                                  'tensorflow_lite_' + model_name + '.pb')

    # First dump the graph and the variable values to intermediate files;
    # freeze_graph works from files on disk, not from the live session.
    session = K.get_session()
    tf.train.write_graph(session.graph_def, output_dir, graph_file)
    tf.train.Saver().save(session, checkpoint_path)

    # Keyword arguments keep the long freeze_graph argument list readable.
    freeze_graph.freeze_graph(
        input_graph=graph_path,
        input_saver=None,
        input_binary=False,  # the graph above was written as text (.pbtxt)
        input_checkpoint=checkpoint_path,
        output_node_names=output_node_name,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=frozen_path,
        clear_devices=True,
        initializer_nodes="")

    # Read the frozen graph back so it can be optimized for inference.
    input_graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, [input_node_name], [output_node_name],
        tf.float32.as_datatype_enum)

    # GFile replaces the deprecated FastGFile.
    with tf.gfile.GFile(optimized_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())
In [10]:
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
# Small CNN built layer by layer: two conv + max-pool stages, then a
# 128-unit ReLU layer and a 7-way softmax output.
model = Sequential()
model.add(Conv2D(8, (3, 3), activation='relu', input_shape=(128, 128, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(8, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dense(7))
model.add(Activation('softmax'))
model.summary()
In [11]:
# Compile with RMSprop and categorical cross-entropy, tracking accuracy.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
In [ ]:
In [12]:
# Load previously trained weights into the model defined above.
# NOTE(review): this is an absolute, machine-specific path; the cell will
# fail on any other machine unless the file exists at this exact location.
# NOTE(review): the file name mentions "32X32" while the model in this
# notebook takes 128x128x3 input -- confirm these weights match.
model.load_weights(
    filepath="/home/kent/git/DeepLearning_ClassmatesImageClassification_jdwang2018_5_25/CNN_Classfier_32X32_jdwang_Weight_1.h5"
)
In [ ]:
In [13]:
# Export the model; tensor names look like "<node>:<index>", so keep only
# the part before the colon to get the graph node names.
export_model_for_mobile(
    'classmate_new',
    model.input.name.partition(":")[0],
    model.output.name.partition(":")[0],
)
In [ ]: