[MSE-01] モジュールをインポートして、乱数のシードを設定します。


In [1]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

np.random.seed(20160604)


/Users/tetsu/.pyenv/versions/miniconda2-latest/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

[MSE-02] MNISTのデータセットを用意します。


In [2]:
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

[MSE-03] ソフトマックス関数による確率 p の計算式を用意します。


In [3]:
x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
w0 = tf.Variable(tf.zeros([10]))
f = tf.matmul(x, w) + w0
p = tf.nn.softmax(f)

[MSE-04] 誤差関数 loss とトレーニングアルゴリズム train_step を用意します。


In [4]:
t = tf.placeholder(tf.float32, [None, 10])
loss = -tf.reduce_sum(t * tf.log(p))
train_step = tf.train.AdamOptimizer().minimize(loss)

[MSE-05] 正解率 accuracy を定義します。


In [5]:
correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

[MSE-06] セッションを用意して、Variableを初期化します。


In [6]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

[MSE-07] パラメーターの最適化を2000回繰り返します。

1回の処理において、トレーニングセットから取り出した100個のデータを用いて、勾配降下法を適用します。

最終的に、テストセットに対して約92%の正解率が得られます。


In [7]:
i = 0
for _ in range(2000):
    i += 1
    batch_xs, batch_ts = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts})
    if i % 100 == 0:
        loss_val, acc_val = sess.run([loss, accuracy],
            feed_dict={x:mnist.test.images, t: mnist.test.labels})
        print ('Step: %d, Loss: %f, Accuracy: %f'
               % (i, loss_val, acc_val))


Step: 100, Loss: 7516.912109, Accuracy: 0.845200
Step: 200, Loss: 5332.066406, Accuracy: 0.876400
Step: 300, Loss: 4517.929199, Accuracy: 0.891200
Step: 400, Loss: 4050.018799, Accuracy: 0.898900
Step: 500, Loss: 3777.578613, Accuracy: 0.904200
Step: 600, Loss: 3594.546143, Accuracy: 0.905700
Step: 700, Loss: 3446.362793, Accuracy: 0.910400
Step: 800, Loss: 3347.382812, Accuracy: 0.910900
Step: 900, Loss: 3245.309082, Accuracy: 0.914200
Step: 1000, Loss: 3175.688232, Accuracy: 0.915500
Step: 1100, Loss: 3105.357178, Accuracy: 0.915100
Step: 1200, Loss: 3072.926025, Accuracy: 0.916500
Step: 1300, Loss: 3036.589844, Accuracy: 0.916000
Step: 1400, Loss: 3004.714355, Accuracy: 0.917100
Step: 1500, Loss: 2957.116943, Accuracy: 0.917800
Step: 1600, Loss: 2907.521973, Accuracy: 0.918500
Step: 1700, Loss: 2908.343750, Accuracy: 0.919500
Step: 1800, Loss: 2901.087402, Accuracy: 0.919400
Step: 1900, Loss: 2868.233398, Accuracy: 0.919800
Step: 2000, Loss: 2866.670166, Accuracy: 0.920800

[MSE-08] この時点のパラメーターを用いて、テストセットに対する予測を表示します。

ここでは、「0」〜「9」の数字に対して、正解と不正解の例を3個ずつ表示します。


In [8]:
images, labels = mnist.test.images, mnist.test.labels
p_val = sess.run(p, feed_dict={x:images, t: labels}) 

fig = plt.figure(figsize=(8,15))
for i in range(10):
    c = 1
    for (image, label, pred) in zip(images, labels, p_val):
        prediction, actual = np.argmax(pred), np.argmax(label)
        if prediction != i:
            continue
        if (c < 4 and i == actual) or (c >= 4 and i != actual):
            subplot = fig.add_subplot(10,6,i*6+c)
            subplot.set_xticks([])
            subplot.set_yticks([])
            subplot.set_title('%d / %d' % (prediction, actual))
            subplot.imshow(image.reshape((28,28)), vmin=0, vmax=1,
                           cmap=plt.cm.gray_r, interpolation="nearest")
            c += 1
            if c > 6:
                break