Load a model from a checkpoint


In [17]:
import tensorflow as tf
import numpy as np


# Build a tiny graph y = x * W, overwrite W with 3.7, save a checkpoint under
# the prefix /tmp/foo1, and evaluate y for two sample inputs.
graph = tf.Graph()
with graph.as_default():
    x_t = tf.placeholder(tf.float32, [], name='x')
    W_t = tf.Variable(2.5, name='W')
    y_t = x_t * W_t
    # Register the input/output tensors in named collections so they can be
    # looked up again after importing the saved meta graph. This must come
    # *after* y_t is created: otherwise y_t is unbound (NameError on a fresh
    # kernel) or refers to a stale tensor left over from an earlier run.
    tf.add_to_collection('x_t', x_t)
    tf.add_to_collection('y_t', y_t)
    # Placeholder + assign op used to overwrite W with a value fed at run time.
    W_new_value_t = tf.placeholder(tf.float32, [])
    W_assign_t = tf.assign(W_t, W_new_value_t)
    saver_t = tf.train.Saver()
    # NOTE(review): this session is never closed; call sess.close() when done.
    sess = tf.Session()
    with sess.as_default():
        sess.run(tf.global_variables_initializer())
        sess.run(W_assign_t, feed_dict={W_new_value_t: 3.7})
        # Saves variables under the prefix /tmp/foo1; the graph definition is
        # read back later from /tmp/foo1.meta.
        saver_t.save(sess, '/tmp/foo1')
        for x in [2.34, 5.67]:
            y = sess.run(y_t, feed_dict={x_t: x})
            print('y', y)

    
# Rebuild the saved graph from its meta-graph file in a fresh tf.Graph,
# restore the checkpointed variable values, and re-evaluate y for the same
# inputs.
graph = tf.Graph()
with graph.as_default():
    # Recreates the saved ops (including the 'x_t'/'y_t' collections) and
    # returns a Saver that can restore the imported variables.
    saver_t = tf.train.import_meta_graph('/tmp/foo1.meta')
    # NOTE(review): this session is never closed; call sess.close() when done.
    sess = tf.Session()
    with sess.as_default():
        saver_t.restore(sess, '/tmp/foo1')
        # Prints the variable's tensor handle, not its restored value.
        print(tf.global_variables()[0])
        # Look the tensors up via the collections registered before saving,
        # rather than relying on auto-generated op names.
        x_t = tf.get_collection('x_t')[0]
        y_t = tf.get_collection('y_t')[0]
        for x in [2.34, 5.67]:
            y = sess.run(y_t, feed_dict={x_t: x})
            print('y', y)


y 8.658
y 20.979
Tensor("W/read:0", shape=(), dtype=float32)
y 8.658
y 20.979