Lab: tf.train.Saver


In [ ]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

print(tf.__version__)

まずは checkpoint を保存するディレクトリを作成しましょう。


In [ ]:
CHECKPOINT_DIR = "saver_sample"

if not os.path.isdir("saver_sample"):
    os.mkdir("saver_sample")

tf.Variable にそれぞれ "v1", "v2" という名前を付けて保存します。


In [ ]:
with tf.Graph().as_default() as g1:
    v1 = tf.Variable(1., name="v1")
    v2 = tf.Variable(2., name="v2")
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

with tf.Session(graph=g1) as sess:
    sess.run(init_op)
    saver.save(sess, os.path.join(CHECKPOINT_DIR, "sample.ckpt"))

sample.ckpt というファイルが生成されていれば無事保存されています。 この中に {"v1": 1, "v2": 2} のような key/value の形式 (実際にこのような JSON で保存されているわけではありません) で名前と値が保存されていると思ってください。

Checkpoint を読み込むときは、以下のように同じ名前の tf.Variable を作ってやれば良いです。


In [ ]:
with tf.Graph().as_default() as g2:
    v1 = tf.Variable(3., name="v1")
    v2 = tf.Variable(4., name="v2")
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

with tf.Session(graph=g2) as sess:
    sess.run(init_op)
    print(sess.run([v1, v2]))
    saver.restore(sess, os.path.join(CHECKPOINT_DIR, "sample.ckpt"))
    print(sess.run([v1, v2]))

saver.restore を実行する前は [3, 4] という値が入っていますが、 checkpoint を読み込んだ後は保存した通り [1, 2] に値が更新されています。