In [1]:
# Export configuration: locations and naming for the Keras -> frozen TF graph conversion.
input_fld = "../data"  # folder containing the saved Keras model file
weight_file = "weights.best.inceptionv3_finetune.hdf5"  # full model file, read with keras.models.load_model
num_output = 1  # how many model outputs get a named identity node in the exported graph
write_graph_def_ascii_flag = True  # NOTE(review): meant to toggle an extra text-format GraphDef dump — confirm it is consumed
prefix_output_node_names_of_final_network = "softmax"  # output node i is named prefix + str(i)
In [2]:
# Standard library first, then TensorFlow / Keras.
import os
import os.path as osp

import tensorflow as tf
from keras import backend as K
from keras.models import load_model

# Absolute-or-relative path of the saved Keras model configured in the first cell.
weight_file_path = osp.join(input_fld, weight_file)
In [3]:
# Load the fine-tuned Keras model and wrap each exported output in an identity op
# with a predictable name, so the frozen graph can be queried by node name.
net_model = load_model(weight_file_path)

# NOTE(review): the learning phase is set *after* the model graph was built, so the
# loaded layers still branch on the `keras_learning_phase` tensor; the later cells that
# rewrite Switch ops and replace that node by hand exist because of this. Calling it
# before `load_model` would likely avoid those conditionals — TODO confirm.
K.set_learning_phase(0)

if num_output > len(net_model.outputs):
    raise ValueError('num_output=%d but the model only has %d output(s)'
                     % (num_output, len(net_model.outputs)))

pred_node_names = [prefix_output_node_names_of_final_network + str(i)
                   for i in range(num_output)]
# `net_model.outputs` is always a list of whole output tensors. Indexing
# `net_model.output[i]` instead would, for a single-output model, slice row i of the
# batch and export only that sample's predictions.
pred = [tf.identity(net_model.outputs[i], name=node_name)
        for i, node_name in enumerate(pred_node_names)]
print('output nodes names are: ', pred_node_names)
In [4]:
# Keras' TensorFlow session: its graph is snapshotted in the next cell and its current
# variable values are baked into constants when the graph is frozen at the end.
sess = K.get_session()
In [5]:
# Take an editable GraphDef copy of the session graph and collect every node whose op
# type contains "switch" (case-insensitive) — presumably the Switch ops of Keras'
# train/inference conditionals.
gd = sess.graph.as_graph_def()
nodes_with_switch_op = [
    node_def for node_def in gd.node if 'switch' in node_def.op.lower()
]
In [6]:
# Reroute inputs that read output 1 of their producer ("name:1") to read output 0
# ("name") instead. Aimed at the second output of Switch ops, which become
# single-output Identity ops later; note the rewrite applies to ANY producer, not only
# Switch ops. Each input is replaced in place, so input order is preserved.
for node_def in gd.node:
    for position, source in enumerate(node_def.input):
        if source.endswith(':1'):
            node_def.input[position] = source[:-2]
In [7]:
# Map node name -> NodeDef. The values are the same message objects stored in `gd`,
# so edits made through this dict modify `gd` directly.
nodes = {node_def.name: node_def for node_def in gd.node}
In [8]:
# Turn every Switch(data, pred) into Identity(data).
# In a NodeDef, regular inputs come first and "^name" control inputs follow, so the
# predicate is at index 1. Popping the *last* input (as before) would drop a control
# input instead whenever one is present. Nodes already converted are skipped, which
# makes the cell safe to re-run (a second pop used to remove the data input).
for switch_node in nodes_with_switch_op:
    if 'switch' not in switch_node.op.lower():
        continue  # already converted on a previous run
    data_inputs = [name for name in switch_node.input if not name.startswith('^')]
    if len(data_inputs) != 2:
        raise ValueError('Unexpected inputs for %s (%s): %s'
                         % (switch_node.name, switch_node.op, list(switch_node.input)))
    switch_node.op = 'Identity'
    del switch_node.input[1]  # drop the predicate, keep data + any control inputs
In [9]:
# Re-scan the graph: after the Switch -> Identity conversion no switch-type op should
# remain. Fail loudly instead of silently exporting a graph that still has them.
nodes_with_switch_op = [x for x in gd.node if x.op.lower().find('switch') != -1]
assert not nodes_with_switch_op, (
    'Switch ops left in graph: %s' % [x.name for x in nodes_with_switch_op])
In [10]:
# Show every node related to the Keras learning-phase flag along with its inputs.
[
    (node_def.name, list(node_def.input))
    for node_def in gd.node
    if 'keras_learning_phase' in node_def.name.lower()
]
Out[10]:
In [11]:
# Display the learning-phase node before editing it.
# NOTE(review): this name comes from this session's graph; the layer-number suffix may
# differ on another run or model — check the listing above.
nodes['batch_normalization_95/keras_learning_phase']
Out[11]:
In [12]:
# Change the learning-phase node's op to Const so the exported graph carries a fixed
# value for it (presumably replacing a placeholder that would otherwise need a feed).
# Its `value` attr is filled in by the next cell.
nodes['batch_normalization_95/keras_learning_phase'].op = 'Const'
In [13]:
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
# Give the Const node its value: learning phase = False (inference).
# Use a scalar rather than a shape-[1] vector. The value stands in for a boolean flag
# that feeds tf.cond, whose predicate must be scalar. NOTE(review): confirm against the
# node's original `shape` attr shown above.
nodes['batch_normalization_95/keras_learning_phase'].attr.get_or_create('value').CopyFrom(
    attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(False, dtypes.bool)))
In [14]:
# Const's op definition only declares `value` and `dtype` attrs, so remove the `shape`
# attr carried over from the original node. Guarded so that re-running the cell does
# not raise KeyError once the attr is already gone.
learning_phase_attrs = nodes['batch_normalization_95/keras_learning_phase'].attr
if 'shape' in learning_phase_attrs:
    del learning_phase_attrs['shape']
In [15]:
# Display the rewritten node to confirm it is now a Const carrying a `value` attr.
nodes['batch_normalization_95/keras_learning_phase']
Out[15]:
In [16]:
from tensorflow.python.framework import graph_util

# Freeze: replace the variables reachable from the named output nodes with Consts
# holding their current values in `sess`, keeping only the subgraph those outputs need.
output_graph_def = graph_util.convert_variables_to_constants(sess, gd, pred_node_names)

# Write next to the input model, reusing `input_fld` instead of re-hard-coding '../data'.
output_graph_path = osp.join(input_fld, 'dog_breed_graph_v2.pb')
with tf.gfile.GFile(output_graph_path, "wb") as f:
    f.write(output_graph_def.SerializeToString())
print('saved the frozen graph (binary) at: ', output_graph_path)

# Honour `write_graph_def_ascii_flag` from the config cell by also dumping the edited,
# unfrozen GraphDef as text. It has no weights, so it stays small and readable.
if write_graph_def_ascii_flag:
    ascii_graph_name = 'dog_breed_graph_v2_graph_def.pbtxt'
    tf.train.write_graph(gd, input_fld, ascii_graph_name, as_text=True)
    print('saved the graph definition (ascii) at: ', osp.join(input_fld, ascii_graph_name))
In [ ]: