In [1]:
import tensorflow as tf
from os import path, makedirs
In [2]:
# Create the directory that checkpoints are written to.
# path.splitext(path.basename("tensor_saving"))[0] is just "tensor_saving",
# so the run name is used directly.
logs_path = path.join("tensor_log", "tensor_saving")
checkpoint_dir = path.abspath(path.join(logs_path, "checkpoints"))
# Files written by saver.save will be named with this prefix.
checkpoint_prefix = path.join(checkpoint_dir, "model")
# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of a separate path.exists() test.
makedirs(checkpoint_dir, exist_ok=True)
In [3]:
# Two ways to create a variable: tf.get_variable with an initializer, and the
# tf.Variable constructor with an initial-value tensor. Both start as three zeros.
# Pass an initializer *instance* (note the call), not the initializer class.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.Variable(tf.zeros([3]), name="v2")
# Initializer op for the global variables; it is built after both variables
# exist so it includes them.
init_op = tf.global_variables_initializer()
In [4]:
# Graph ops that change the variables only when they are run:
# dec_v2 subtracts 1 from every element of v2, inc_v1 adds 1 to every element of v1.
dec_v2 = v2.assign(v2 - 1)
inc_v1 = v1.assign(v1 + 1)
In [5]:
# Show the contents of the graph's global-variables collection.
print("All Variables:\n", tf.global_variables())

# Saver built with no arguments. A list of variables can be passed instead,
# e.g. tf.train.Saver(tf.global_variables()).
saver = tf.train.Saver()
In [6]:
# Open a session on the default graph and run the initializer op in it,
# so v1 and v2 hold their starting values before anything reads them.
sess = tf.Session()
init_op.run(session=sess)
In [7]:
# Fetch both variables in a single session call and show their starting values.
v1_init_value, v2_init_value = sess.run([v1, v2])
print("v1 Init: ", v1_init_value)
print("v2 Init: ", v2_init_value)
In [8]:
# Run each update op once, show the new values, then write a checkpoint.
# Note: re-running this cell applies the updates again (v1 += 1, v2 -= 1 per run).
inc_v1.op.run(session=sess)
dec_v2.op.run(session=sess)
print("v1 Post: ", v1.eval(session=sess))
print("v2 Post: ", v2.eval(session=sess))
# Bind the returned checkpoint path to `save_path`, NOT `path`: `path` is the
# os.path module imported in the first cell, and shadowing it with a string
# breaks every later path.* call (e.g. re-running the directory-setup cell).
save_path = saver.save(sess, checkpoint_prefix)
print("Model saved in path:", save_path)
In [12]:
# Close the session, then discard the default graph so the next cell can
# declare variables named "v1"/"v2" again from scratch.
# NOTE(review): per the TF docs, calling reset_default_graph() while a session
# is still open is undefined behavior, so keep close() before the reset.
sess.close()
tf.reset_default_graph()
In [13]:
# Rebuild the graph in the fresh default graph and open a new session.
# The variables are re-declared with the same names ("v1", "v2") as the graph
# that was checkpointed, so the Saver can match them when restoring.
v1_new = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2_new = tf.Variable(tf.zeros([3]), name="v2")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
# Initialize so the variables have values (zeros) to display before the
# checkpoint is restored over them.
sess.run(init_op)
In [15]:
# Compare the freshly initialized zeros with the values read back from the
# checkpoint written earlier at checkpoint_prefix.
print("Initial Parameters")
v1_before, v2_before = sess.run([v1_new, v2_new])
print("v1 Init: ", v1_before)
print("v2 Init: ", v2_before)

saver.restore(sess, checkpoint_prefix)

print("Restored Parameters")
v1_after, v2_after = sess.run([v1_new, v2_new])
print("v1 Loaded: ", v1_after)
print("v2 Loaded: ", v2_after)
In [ ]: