# In [17]:  (notebook cell prompt, kept as a comment so the file parses as Python)
import tensorflow as tf
import numpy as np
# Build a minimal graph computing y = x * W, overwrite W through an assign op,
# checkpoint the result under the prefix '/tmp/foo1', and evaluate y for two
# sample inputs.
graph = tf.Graph()
with graph.as_default():
    x_t = tf.placeholder(tf.float32, [], name='x')
    W_t = tf.Variable(2.5, name='W')
    y_t = x_t * W_t

    # Register the input and output tensors in named collections so they can
    # be looked up by key after the graph is re-imported from the meta file.
    # y_t has to be created before it is registered; registering it earlier
    # raises NameError.
    tf.add_to_collection('x_t', x_t)
    tf.add_to_collection('y_t', y_t)

    # Separate placeholder + assign op so W can be overwritten at run time.
    W_new_value_t = tf.placeholder(tf.float32, [])
    W_assign_t = tf.assign(W_t, W_new_value_t)
    saver_t = tf.train.Saver()

    # Bind the session to this graph explicitly; the context manager installs
    # it as the default session and closes it on exit.
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(W_assign_t, feed_dict={W_new_value_t: 3.7})
        saver_t.save(sess, '/tmp/foo1')
        for x in [2.34, 5.67]:
            y = sess.run(y_t, feed_dict={x_t: x})
            print('y', y)
# Reload the checkpoint written to '/tmp/foo1' into a brand-new graph and
# re-evaluate y for the same sample inputs.
graph = tf.Graph()
with graph.as_default():
    # Recreate the graph structure (including the 'x_t' / 'y_t' collections)
    # from the meta file; this returns a Saver for the imported graph.
    saver_t = tf.train.import_meta_graph('/tmp/foo1.meta')

    # Bind the session to the freshly imported graph; the context manager
    # closes it on exit.
    with tf.Session(graph=graph) as sess:
        # Restore with the same prefix that was passed to save(): '/tmp/foo1',
        # not '/tmp/foo1.ckpt'.
        saver_t.restore(sess, '/tmp/foo1')
        print(tf.global_variables()[0])

        # Fetch the input/output tensors by collection key rather than by
        # Python reference, since this graph was built from the meta file.
        x_t = tf.get_collection('x_t')[0]
        y_t = tf.get_collection('y_t')[0]
        for x in [2.34, 5.67]:
            y = sess.run(y_t, feed_dict={x_t: x})
            print('y', y)