[MSL-01] 必要なモジュールをインポートして、乱数のシードを設定します。
In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
np.random.seed(20160612)
tf.set_random_seed(20160612)
[MSL-02] MNISTのデータセットを用意します。
In [2]:
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
[MSL-03] 単層ニューラルネットワークを用いた確率 p の計算式を用意します。
In [3]:
num_units = 1024
x = tf.placeholder(tf.float32, [None, 784])
w1 = tf.Variable(tf.truncated_normal([784, num_units]))
b1 = tf.Variable(tf.zeros([num_units]))
hidden1 = tf.nn.relu(tf.matmul(x, w1) + b1)
w0 = tf.Variable(tf.zeros([num_units, 10]))
b0 = tf.Variable(tf.zeros([10]))
p = tf.nn.softmax(tf.matmul(hidden1, w0) + b0)
[MSL-04] 誤差関数 loss、トレーニングアルゴリズム train_step、正解率 accuracy を定義します。
In [4]:
t = tf.placeholder(tf.float32, [None, 10])
loss = -tf.reduce_sum(t * tf.log(p))
train_step = tf.train.AdamOptimizer().minimize(loss)
correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
[MSL-05] セッションを用意して、Variableを初期化します。
In [5]:
sess = tf.Session()
sess.run(tf.initialize_all_variables())
[MSL-06] パラメーターの最適化を2000回繰り返します。
1回の処理において、トレーニングセットから取り出した100個のデータを用いて、勾配降下法を適用します。
最終的に、テストセットに対して約97%の正解率が得られます。
In [6]:
i = 0
for _ in range(2000):
i += 1
batch_xs, batch_ts = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts})
if i % 100 == 0:
loss_val, acc_val = sess.run([loss, accuracy],
feed_dict={x:mnist.test.images, t: mnist.test.labels})
print ('Step: %d, Loss: %f, Accuracy: %f'
% (i, loss_val, acc_val))
[MSL-07] 最適化されたパラメーターを用いて、テストセットに対する予測を表示します。
ここでは、「0」〜「9」の数字に対して、正解と不正解の例を3個ずつ表示します。
In [7]:
images, labels = mnist.test.images, mnist.test.labels
p_val = sess.run(p, feed_dict={x:images, t: labels})
fig = plt.figure(figsize=(8,15))
for i in range(10):
c = 1
for (image, label, pred) in zip(images, labels, p_val):
prediction, actual = np.argmax(pred), np.argmax(label)
if prediction != i:
continue
if (c < 4 and i == actual) or (c >= 4 and i != actual):
subplot = fig.add_subplot(10,6,i*6+c)
subplot.set_xticks([])
subplot.set_yticks([])
subplot.set_title('%d / %d' % (prediction, actual))
subplot.imshow(image.reshape((28,28)), vmin=0, vmax=1,
cmap=plt.cm.gray_r, interpolation="nearest")
c += 1
if c > 6:
break