[CNN-01] Import the required modules and set the random seeds.


In [1]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

np.random.seed(20160704)
tf.set_random_seed(20160704)


[CNN-02] Prepare the MNIST dataset.


In [2]:
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
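
As a quick check, the loaded dataset can be inspected before building the network. The sketch below assumes the mnist object created above; with one_hot=True, each image is a flattened 784-dimensional vector and each label a 10-dimensional one-hot vector (55,000 training and 10,000 test samples in the standard split).

# Sketch: inspect the shapes of the loaded MNIST arrays (assumes the mnist object above).
print(mnist.train.images.shape, mnist.train.labels.shape)  # (55000, 784) (55000, 10)
print(mnist.test.images.shape, mnist.test.labels.shape)    # (10000, 784) (10000, 10)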

[CNN-03] Define the first-stage convolutional filters and pooling layer.


In [3]:
num_filters1 = 32

x = tf.placeholder(tf.float32, [None, 784])
x_image = tf.reshape(x, [-1,28,28,1])

W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1],
                                          stddev=0.1))
h_conv1 = tf.nn.conv2d(x_image, W_conv1,
                       strides=[1,1,1,1], padding='SAME')

b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))
h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)

h_pool1 = tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],
                         strides=[1,2,2,1], padding='SAME')
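
With padding='SAME' and stride 1, the convolution preserves the 28x28 spatial size, and the 2x2 max pooling with stride 2 halves it to 14x14. A minimal shape check against the tensors defined above:

# Sketch: confirm the tensor shapes produced by the first stage.
print(h_conv1.get_shape())  # (?, 28, 28, 32)
print(h_pool1.get_shape())  # (?, 14, 14, 32)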

[CNN-04] Define the second-stage convolutional filters and pooling layer.


In [4]:
num_filters2 = 64

W_conv2 = tf.Variable(
            tf.truncated_normal([5,5,num_filters1,num_filters2],
                                stddev=0.1))
h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,
                       strides=[1,1,1,1], padding='SAME')

b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))
h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)

h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],
                         strides=[1,2,2,1], padding='SAME')
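
The second stage applies 64 filters to the 14x14x32 output of the first stage and pools the result down to 7x7. A minimal shape check:

# Sketch: confirm the tensor shapes produced by the second stage.
print(h_conv2.get_shape())  # (?, 14, 14, 64)
print(h_pool2.get_shape())  # (?, 7, 7, 64)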

[CNN-05] Define the fully connected layer, dropout layer, and softmax function.


In [5]:
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])

num_units1 = 7*7*num_filters2
num_units2 = 1024

w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))
b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))
hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)

keep_prob = tf.placeholder(tf.float32)
hidden2_drop = tf.nn.dropout(hidden2, keep_prob)

w0 = tf.Variable(tf.zeros([num_units2, 10]))
b0 = tf.Variable(tf.zeros([10]))
p = tf.nn.softmax(tf.matmul(hidden2_drop, w0) + b0)
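
The pooled features are flattened to 7*7*64 = 3136 values per image, passed through a 1024-unit hidden layer with dropout (the keep probability is fed through the keep_prob placeholder: 0.5 during training, 1.0 during evaluation), and finally mapped to 10 class probabilities. A minimal shape check:

# Sketch: confirm the flattened feature size and the output size.
print(h_pool2_flat.get_shape())  # (?, 3136)
print(p.get_shape())             # (?, 10)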

[CNN-06] Define the loss function loss, the training algorithm train_step, and the accuracy metric accuracy.


In [6]:
t = tf.placeholder(tf.float32, [None, 10])
loss = -tf.reduce_sum(t * tf.log(p))
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)
correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
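
loss here is the cross entropy -sum(t * log(p)), computed directly from the softmax output. As an alternative (not what this notebook runs), the same quantity can be computed from the logits with tf.nn.softmax_cross_entropy_with_logits, which is numerically more stable when p contains values close to zero. The sketch below assumes logits is the pre-softmax output tf.matmul(hidden2_drop, w0) + b0 from [CNN-05].

# Sketch (alternative, not used in this notebook): cross entropy from logits,
# more stable than taking the log of the softmax output.
logits = tf.matmul(hidden2_drop, w0) + b0
loss_alt = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=t, logits=logits))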

[CNN-07] Prepare a session and initialize the Variables.


In [7]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
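
The Saver created here covers every Variable defined so far, so the checkpoints written during training can later be used to restore the trained filters and weights. For a quick sanity check, the trainable variables and their shapes can be listed:

# Sketch: list the trainable Variables that the Saver will manage.
for v in tf.trainable_variables():
    print(v.name, v.get_shape())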

[CNN-08] Repeat the parameter optimization for 20,000 iterations.

This ultimately reaches an accuracy of about 99% on the test set.


In [8]:
i = 0
for _ in range(20000):
    i += 1
    batch_xs, batch_ts = mnist.train.next_batch(50)
    sess.run(train_step,
             feed_dict={x:batch_xs, t:batch_ts, keep_prob:0.5})
    if i % 500 == 0:
        loss_vals, acc_vals = [], []
        for c in range(4):
            start = int(len(mnist.test.labels) / 4 * c)
            end = int(len(mnist.test.labels) / 4 * (c+1))
            loss_val, acc_val = sess.run([loss, accuracy],
                feed_dict={x:mnist.test.images[start:end],
                           t:mnist.test.labels[start:end],
                           keep_prob:1.0})
            loss_vals.append(loss_val)
            acc_vals.append(acc_val)
        loss_val = np.sum(loss_vals)
        acc_val = np.mean(acc_vals)
        print('Step: %d, Loss: %f, Accuracy: %f'
              % (i, loss_val, acc_val))
        saver.save(sess, './cnn_session', global_step=i)


Step: 500, Loss: 1517.676758, Accuracy: 0.952400
Step: 1000, Loss: 938.293213, Accuracy: 0.972000
Step: 1500, Loss: 747.728638, Accuracy: 0.976300
Step: 2000, Loss: 621.786865, Accuracy: 0.979600
Step: 2500, Loss: 586.465210, Accuracy: 0.980500
Step: 3000, Loss: 529.665466, Accuracy: 0.982200
Step: 3500, Loss: 462.990417, Accuracy: 0.983800
Step: 4000, Loss: 473.560028, Accuracy: 0.984100
Step: 4500, Loss: 409.029968, Accuracy: 0.984700
Step: 5000, Loss: 398.725830, Accuracy: 0.986200
Step: 5500, Loss: 419.560913, Accuracy: 0.986100
Step: 6000, Loss: 383.770203, Accuracy: 0.986900
Step: 6500, Loss: 347.540405, Accuracy: 0.987800
Step: 7000, Loss: 365.736145, Accuracy: 0.987300
Step: 7500, Loss: 338.154022, Accuracy: 0.988300
Step: 8000, Loss: 364.098969, Accuracy: 0.988000
Step: 8500, Loss: 327.820312, Accuracy: 0.988300
Step: 9000, Loss: 321.452393, Accuracy: 0.989300
Step: 9500, Loss: 305.152161, Accuracy: 0.989700
Step: 10000, Loss: 321.682190, Accuracy: 0.989000
Step: 10500, Loss: 319.188110, Accuracy: 0.989000
Step: 11000, Loss: 312.048767, Accuracy: 0.989200
Step: 11500, Loss: 287.300629, Accuracy: 0.990800
Step: 12000, Loss: 303.551636, Accuracy: 0.990000
Step: 12500, Loss: 302.909607, Accuracy: 0.989800
Step: 13000, Loss: 311.328461, Accuracy: 0.989400
Step: 13500, Loss: 294.923615, Accuracy: 0.990100
Step: 14000, Loss: 287.322632, Accuracy: 0.990200
Step: 14500, Loss: 288.289429, Accuracy: 0.990000
Step: 15000, Loss: 264.954895, Accuracy: 0.991300
Step: 15500, Loss: 273.416687, Accuracy: 0.992100
Step: 16000, Loss: 299.554626, Accuracy: 0.990300
Step: 16500, Loss: 292.068726, Accuracy: 0.990200
Step: 17000, Loss: 275.824036, Accuracy: 0.991000
Step: 17500, Loss: 247.039978, Accuracy: 0.992300
Step: 18000, Loss: 272.884399, Accuracy: 0.991100
Step: 18500, Loss: 274.981110, Accuracy: 0.990800
Step: 19000, Loss: 268.536926, Accuracy: 0.991200
Step: 19500, Loss: 254.943939, Accuracy: 0.991700
Step: 20000, Loss: 245.394318, Accuracy: 0.992000
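
Inside the loop, the 10,000 test images are evaluated in four equal chunks so that the feed-forward pass through the convolutional layers does not consume too much memory at once; the reported loss is the sum over the chunks and, because the chunks have equal size, the mean of the chunk accuracies equals the accuracy on the whole test set. If memory allows, a single-shot final evaluation is a minimal sketch:

# Sketch: one-shot evaluation over the full test set (assumes enough memory
# to run all 10,000 images through the network at once).
final_loss, final_acc = sess.run([loss, accuracy],
    feed_dict={x: mnist.test.images, t: mnist.test.labels, keep_prob: 1.0})
print('Loss: %f, Accuracy: %f' % (final_loss, final_acc))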

[CNN-09] Confirm that the files storing the session state have been generated.


In [9]:
!ls cnn_session*


cnn_session-18000.data-00000-of-00001  cnn_session-19000.meta
cnn_session-18000.index		       cnn_session-19500.data-00000-of-00001
cnn_session-18000.meta		       cnn_session-19500.index
cnn_session-18500.data-00000-of-00001  cnn_session-19500.meta
cnn_session-18500.index		       cnn_session-20000.data-00000-of-00001
cnn_session-18500.meta		       cnn_session-20000.index
cnn_session-19000.data-00000-of-00001  cnn_session-20000.meta
cnn_session-19000.index
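
Each checkpoint consists of .data, .index, and .meta files. With the Saver's default max_to_keep=5, only the five most recent checkpoints are retained, which is why only steps 18000 through 20000 appear above. A minimal sketch for restoring the trained parameters later, assuming the checkpoint state file written by saver.save is still in the current directory:

# Sketch: restore the variables from the most recent checkpoint.
ckpt = tf.train.latest_checkpoint('.')
saver.restore(sess, ckpt)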