In [1]:
# ---- User configuration: edit these values before running the notebook ----

# Folder holding the trained Keras model, and the model file name inside it.
input_fld = 'input_fld_path'
weight_file = 'kerasmodel_file_name located inside input_fld'

# Comma-separated output names of the Keras network; only the number of
# entries is used when the model outputs are wrapped.
output_node_names_of_input_network = "pred0"
# Comma-separated names given to the exported outputs; these are the nodes
# freeze_graph is asked to keep.
output_node_names_of_final_network = 'output_node'

# Whether to also write a human-readable (text) copy of the GraphDef.
write_graph_def_ascii_flag = True
# File name of the frozen graph written inside the output folder.
output_graph_name = 'constant_graph_weights.pb'
In [2]:
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
# Converted artifacts go into a 'tensorflow_model' sub-folder of input_fld.
# osp.join inserts the separator that plain string concatenation dropped when
# input_fld has no trailing slash; the final '' keeps a trailing separator so
# code that appends file names with `output_fld + name` still yields valid paths.
output_fld = osp.join(input_fld, 'tensorflow_model', '')
# makedirs also creates missing parent folders (os.mkdir would raise instead).
if not osp.isdir(output_fld):
    os.makedirs(output_fld)
weight_file_path = osp.join(input_fld, weight_file)
In [3]:
# Keras decides dropout / batch-norm behaviour from a global learning phase when
# layers are built. Fix it to 0 (inference) *before* loading, so the exported graph
# contains only the inference branches; changing it after the model exists does
# not rewrite ops that were already created.
from keras import backend as K
K.set_learning_phase(0)
net_model = load_model(weight_file_path)

# Strip stray spaces so "a, b" still yields valid TensorFlow op names.
input_network_output_names = [name.strip() for name in output_node_names_of_input_network.split(',')]
pred_node_names = [name.strip() for name in output_node_names_of_final_network.split(',')]
num_output = len(input_network_output_names)
if len(pred_node_names) != num_output:
    raise ValueError(
        'Expected %d names in output_node_names_of_final_network, got %d: %r'
        % (num_output, len(pred_node_names), pred_node_names))

# Model.outputs is always a list of tensors. Model.output is a bare tensor for
# single-output models, where `[i]` would slice the batch dimension instead.
pred = [tf.identity(net_model.outputs[i], name=pred_node_names[i])
        for i in range(num_output)]
print('output nodes names are: ', pred_node_names)
In [4]:
# Only a cell's last expression is displayed, so show both values as one tuple.
pred, pred_node_names
Out[4]:
In [5]:
import keras
from keras.models import load_model
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.saved_model import utils
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
# from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import math_ops
# from tensorflow.python.ops import variable
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
from tensorflow.contrib.session_bundle import exporter
In [6]:
import tensorflow as tf
# Checkpoint prefix passed to saver.save; with global_step=0 the checkpoint is
# later addressed as checkpoint_file + '-0'.
checkpoint_file = output_fld
checkpoint_state_name = "checkpoint_state"
# Must name the same file that tf.train.write_graph(..., output_fld,
# 'only_the_graph_def.pb') produces, so build it the same way (join).
input_graph_name = osp.join(output_fld, 'only_the_graph_def.pb')
output_graph_path = osp.join(output_fld, output_graph_name)
# NOTE(review): not used by any cell visible here — confirm whether still needed.
keras_architecture = osp.join(output_fld, 'keras_architecture_json')
keras_weights = osp.join(output_fld, 'keras_weights')
In [7]:
from keras import backend as K
# Capture the Keras session/graph, checkpoint its variables, and write the
# GraphDef that freeze_graph reads back later.
sess = K.get_session()
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
saver = tf.train.Saver()
# Use the configured state-file name instead of repeating the literal.
saver.save(sess, checkpoint_file, global_step=0, latest_filename=checkpoint_state_name)

# The text copy is optional and only for human inspection.
if write_graph_def_ascii_flag:
    tf.train.write_graph(input_graph_def, output_fld, 'only_the_graph_def.pb.ascii', as_text=True)
# The binary copy is always required: freeze_graph loads it with input_binary=True.
tf.train.write_graph(input_graph_def, output_fld, 'only_the_graph_def.pb', as_text=False)

# Graph for TensorBoard; flush so the event file is actually on disk.
writer = tf.summary.FileWriter(output_fld, graph=graph)
writer.flush()
In [8]:
import freeze_graph
import os
# NOTE(review): the graph frozen below is read back from the file at input_graph_name,
# which was written before this line runs. Overriding K._LEARNING_PHASE here therefore
# cannot change that graph; the learning phase has to be fixed before the model is built.
K._LEARNING_PHASE = tf.constant(0)

# Settings forwarded positionally to freeze_graph.freeze_graph below.
clear_devices = False
filename_tensor_name = "save/Const:0"
restore_op_name = "save/restore_all"
input_binary = True  # input_graph_name points at the binary (not text) GraphDef
input_saver_def_path = ""  # deprecated

freeze_graph.freeze_graph(
    input_graph_name,
    input_saver_def_path,
    input_binary,
    checkpoint_file + '-0',  # saver.save(..., global_step=0) appends '-0'
    output_node_names_of_final_network,
    restore_op_name,
    filename_tensor_name,
    output_graph_path,
    clear_devices,
    "",    # presumably initializer_nodes — confirm against the local freeze_graph.py
    None,  # presumably variable_names_whitelist — confirm against freeze_graph.py
)