In [1]:
import tensorflow as tf
from os import path, makedirs

In [2]:
# Create the directory tree used for saving checkpoints.
# logs_path is simply "tensor_log/tensor_saving"; the previous
# splitext(basename(...)) wrapper around a string literal was a no-op.
logs_path = path.join("tensor_log", "tensor_saving")
checkpoint_dir = path.abspath(path.join(logs_path, "checkpoints"))
# Prefix handed to Saver.save / Saver.restore; the Saver writes its files under it.
checkpoint_prefix = path.join(checkpoint_dir, "model")
# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
makedirs(checkpoint_dir, exist_ok=True)

In [3]:
# Two ways to create a variable: tf.get_variable (looked up by name) and the
# tf.Variable constructor. Both start as three zeros.
# Begin from an empty default graph so re-running this cell neither fails on
# the already-existing "v1" nor silently creates a renamed "v2_1" (the later
# restore relies on the names "v1" and "v2").
tf.reset_default_graph()
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.Variable(tf.zeros([3]), name="v2")

# Single op that assigns every global variable its initial value.
init_op = tf.global_variables_initializer()

In [4]:
# Build (but do not run) two update ops: v1 <- v1 + 1 and v2 <- v2 - 1.
v1_incremented = v1 + 1
inc_v1 = v1.assign(v1_incremented)
v2_decremented = v2 - 1
dec_v2 = v2.assign(v2_decremented)

In [5]:
# With no arguments the Saver handles every global variable in the graph;
# an explicit list (for example tf.global_variables()) may be passed instead.
saver = tf.train.Saver()
all_variables = tf.global_variables()
print("All Variables:\n", all_variables)


All Variables:
 [<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(3,) dtype=float32_ref>]

In [6]:
# Open a session on the default graph and give each variable its initial value.
sess = tf.Session()
init_op.run(session=sess)

In [7]:
# Show the values the variables hold right after initialization.
for label, var in (("v1 Init: ", v1), ("v2 Init: ", v2)):
    print(label, var.eval(session=sess))


v1 Init:
 [ 0.  0.  0.]
v2 Init:
 [ 0.  0.  0.]

In [8]:
# Apply the update ops (v1 += 1, v2 -= 1), show the new values, then write a
# checkpoint of the session's current variable values.
inc_v1.op.run(session=sess)
dec_v2.op.run(session=sess)

print("v1 Post: ", v1.eval(session=sess))
print("v2 Post: ", v2.eval(session=sess))

# Keep the returned checkpoint path under its own name. Assigning it to `path`
# would shadow the `os.path` module imported at the top of the notebook and
# break any later or re-run cell that calls path.join / path.abspath.
save_path = saver.save(sess, checkpoint_prefix)


v1 Post:  [ 1.  1.  1.]
v2 Post:  [-1. -1. -1.]

In [12]:
# Close the session and clear the default graph. The next cell re-creates
# variables named "v1" and "v2" from scratch. Without the reset,
# tf.get_variable("v1") would find the existing variable.
sess.close()
tf.reset_default_graph()

In [13]:
# Rebuild the graph: recreate the variables under the same names ("v1", "v2")
# they were saved with, so the Saver can match them when restoring.
# Resetting the default graph first keeps this cell re-runnable on its own.
tf.reset_default_graph()
v1_new = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2_new = tf.Variable(tf.zeros([3]), name="v2")

init_op = tf.global_variables_initializer()

# A new Saver bound to the rebuilt graph's variables.
saver = tf.train.Saver()

# Fresh session; after init_op the variables hold their zero initial values.
sess = tf.Session()
sess.run(init_op)

In [15]:
def _show_values(heading, labelled_vars, session):
    """Print `heading`, then each (label, variable) pair's current value in `session`."""
    print(heading)
    for label, var in labelled_vars:
        print(label, var.eval(session=session))


# Show the freshly initialized values, load the checkpoint, then show them again.
_show_values("Initial Parameters", [("v1 Init: ", v1_new), ("v2 Init: ", v2_new)], sess)

saver.restore(sess, checkpoint_prefix)

_show_values("Restored Parameters", [("v1 Loaded: ", v1_new), ("v2 Loaded: ", v2_new)], sess)


Initial Parameters
v1 Init:  [ 1.  1.  1.]
v2 Init:  [-1. -1. -1.]
INFO:tensorflow:Restoring parameters from /Users/zgoldstein/projects/ai/tensor_tutorial/tensor_log/tensor_saving/checkpoints/model
Restored Parameters
v1 Loaded:  [ 1.  1.  1.]
v2 Loaded:  [-1. -1. -1.]

In [ ]: