In [ ]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
print(tf.__version__)
まずは checkpoint を保存するディレクトリを作成しましょう。
In [ ]:
CHECKPOINT_DIR = "saver_sample"
if not os.path.isdir("saver_sample"):
os.mkdir("saver_sample")
tf.Variable にそれぞれ "v1", "v2" という名前を付けて保存します。
In [ ]:
with tf.Graph().as_default() as g1:
v1 = tf.Variable(1., name="v1")
v2 = tf.Variable(2., name="v2")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session(graph=g1) as sess:
sess.run(init_op)
saver.save(sess, os.path.join(CHECKPOINT_DIR, "sample.ckpt"))
sample.ckpt というファイルが生成されていれば無事保存されています。
この中に {"v1": 1, "v2": 2} のような key/value の形式 (実際にこのような JSON で保存されているわけではありません) で名前と値が保存されていると思ってください。
Checkpoint を読み込むときは、以下のように同じ名前の tf.Variable を作ってやれば良いです。
In [ ]:
with tf.Graph().as_default() as g2:
v1 = tf.Variable(3., name="v1")
v2 = tf.Variable(4., name="v2")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session(graph=g2) as sess:
sess.run(init_op)
print(sess.run([v1, v2]))
saver.restore(sess, os.path.join(CHECKPOINT_DIR, "sample.ckpt"))
print(sess.run([v1, v2]))
saver.restore を実行する前は [3, 4] という値が入っていますが、 checkpoint を読み込んだ後は保存した通り [1, 2] に値が更新されています。