In [3]:
# %load /Users/facai/Study/book_notes/preconfig.py
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
sns.set(font='SimHei', font_scale=2.5)
plt.rcParams['axes.grid'] = False
import tensorflow as tf
def show_image(filename, figsize=None, res_dir=True):
    if figsize:
        plt.figure(figsize=figsize)

    if res_dir:
        filename = './res/{}'.format(filename)

    plt.imshow(plt.imread(filename))
Reference: https://www.tensorflow.org/programmers_guide/graphs
tf.Graph
op, tensor
variable
name_scope, variable_scope, collection (a small sketch follows this list)
save and restore
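Since scoping is not demonstrated in the cells below, here is a minimal sketch (my own names and values, using standard TF 1.x APIs) of how name_scope, variable_scope, and collections interact: tf.get_variable honors variable_scope but ignores name_scope, ordinary ops pick up both, and every variable is registered in the GLOBAL_VARIABLES collection.
g = tf.Graph()
with g.as_default():
    with tf.name_scope("ns"):
        with tf.variable_scope("vs"):
            w = tf.get_variable("w", shape=[1])  # variable name ignores name_scope
            y = tf.add(w, 1, name="y")           # op name picks up both scopes
    print(w.name)  # vs/w:0
    print(y.name)  # ns/vs/y:0
    print(g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))  # [<tf.Variable 'vs/w:0' ...>]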
In [6]:
a = tf.constant(1)
b = a * 2
b
Out[6]:
In [7]:
b.op
Out[7]:
In [11]:
b.consumers()
Out[11]:
In [15]:
a.op
Out[15]:
In [19]:
a.consumers()
Out[19]:
Operator overloads such as __add__ are registered on Tensor in tensorflow/python/framework/ops.py, which is why expressions like a * 2 build graph nodes.
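A quick sketch (mine, assuming the usual TF 1.x operator overloads) to confirm that a Python operator on a Tensor builds the same kind of node as the explicit op, reusing a from the cell above:
b1 = a * 2               # dispatches through Tensor.__mul__
b2 = tf.multiply(a, 2)   # explicit op
print(b1.op.type, b2.op.type)  # both are "Mul"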
In [8]:
b.op.outputs
Out[8]:
In [9]:
list(b.op.inputs)
Out[9]:
In [14]:
print(b.op.inputs[0])
print(a)
In [17]:
list(a.op.inputs)
Out[17]:
Operations and Tensors together make up the graph, and the links run both ways: an op knows its input and output tensors (op.inputs / op.outputs), and a tensor knows its producing op and its consumers (tensor.op / tensor.consumers()).
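A minimal sketch of walking those links by hand, starting from the b defined above:
def walk(tensor, depth=0):
    # print the producing op of each tensor, then recurse into that op's inputs
    op = tensor.op
    print("  " * depth + "{} <- {} ({})".format(tensor.name, op.name, op.type))
    for t in op.inputs:
        walk(t, depth + 1)

walk(b)  # b <- mul (Mul), then the two constant inputs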
# run
with tf.Session() as sess:
    print(sess.run([b]))
Reference:
In [20]:
v = tf.Variable([0])
c = b + v
c
Out[20]:
In [23]:
list(c.op.inputs)
Out[23]:
In [25]:
c.op.inputs[1].op
Out[25]:
In [26]:
list(c.op.inputs[1].op.inputs)
Out[26]:
In [21]:
v
Out[21]:
In fact, reading a variable goes through a tf.identity op, so the addition above is effectively:
c = tf.add(b, tf.identity(v))
Reference: https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable
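A quick check of that claim (a sketch; the exact tensor name depends on how many variables the graph already holds):
read_t = c.op.inputs[1]
print(read_t.name)              # e.g. Variable/read:0
print(read_t.op.type)           # Identity
print(list(read_t.op.inputs))   # the variable itself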
class Layer:
    def build(self):
        pass

    def call(self, inputs):
        pass
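A minimal concrete layer following that skeleton (a sketch of mine, assuming tf.layers.Layer from TF 1.x): build() creates the variables once the input shape is known, call() defines the forward computation.
class Scale(tf.layers.Layer):
    def build(self, input_shape):
        # one trainable scalar, created lazily on the first call
        self.alpha = self.add_variable("alpha", shape=[], initializer=tf.ones_initializer())
        super(Scale, self).build(input_shape)

    def call(self, inputs):
        return inputs * self.alpha

x = tf.placeholder(tf.float32, shape=[None, 3])
y = Scale()(x)   # __call__ runs build() once, then call()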
Reference: https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard
In [58]:
graph_a = tf.Graph()
with graph_a.as_default():
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
    print(v1)
    inc_v1 = v1.assign(v1 + 1)
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        inc_v1.op.run()
        save_path = saver.save(sess, "./tmp/model.ckpt", write_meta_graph=True)
        print("Model saved in path: %s" % save_path)

        pb_path = tf.train.write_graph(graph_a.as_graph_def(), "./tmp/", "graph.pbtxt", as_text=True)
        print("Graph saved in path: %s" % pb_path)
Excerpt from graph.pbtxt showing the node for v1 + 1:
node {
  name: "add"
  op: "Add"
  input: "v1/read"
  input: "add/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
In [62]:
graph_b = tf.Graph()
with graph_b.as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./tmp/model.ckpt.meta')
        saver.restore(sess, "./tmp/model.ckpt")
        print(graph_b.get_operations())

        v1 = graph_b.get_tensor_by_name("v1:0")
        print("------------------")
        print("v1 : %s" % v1.eval(session=sess))
Summary:
tf.train.Saver saves both the GraphDef and the Variable values, so it can restore the graph directly. tf.train.write_graph, tf.GraphDef, and tf.import_graph_def are mainly used for freezing a model (GraphDef only, without variable values). Reference:
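As a sketch of that GraphDef-only path (my own, reusing the graph.pbtxt written above): the text proto can be parsed and imported, but it restores only the structure; variable values would have to be re-initialized or loaded separately.
from google.protobuf import text_format

graph_c = tf.Graph()
with graph_c.as_default():
    graph_def = tf.GraphDef()
    with open("./tmp/graph.pbtxt") as f:
        text_format.Merge(f.read(), graph_def)
    tf.import_graph_def(graph_def, name="")
    print(graph_c.get_operations())  # same node names as graph_a, but no values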
In [ ]: