In [ ]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
print(tf.__version__)
TensorFlow 付属のモジュールを使って MNIST データセットをダウンロードします。
In [ ]:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
ニューラルネットの入力となる Tensor を tf.placeholder で用意します。 学習の際にランダムサンプリングしたデータを使って weight の更新を行うので、後から使うデータを変更できるように tf.placeholder を使います。
また、データをいくつずつ渡していくかも後で自由に決められるように、 Tensor の shape を [None, 784] と指定しています。
TensorFlow では tf.placeholder の shape を指定する時に、不明な場合は None とすることが可能です。
ただし、一部のオペレーションは Tensor の shape がきちんと定義されていないと実行できない場合があるので注意してください。
In [ ]:
x_ph = tf.placeholder(tf.float32, [None, 784])
y_ph = tf.placeholder(tf.float32, [None, 10])
以下が tf.layers を使ってニューラルネットのノードや辺にあたる部分を作成するコードです。
tf.layers.dense は一般的な全結合層を追加する関数です。
In [ ]:
hidden = tf.layers.dense(x_ph, 20)
logits = tf.layers.dense(hidden, 10)
y = tf.nn.softmax(logits)
損失関数として cross entropy を定義します。
In [ ]:
cross_entropy = -tf.reduce_mean(y_ph * tf.log(y))
学習に直接必要な部分ではありませんが、正答率を計算するためのオペレーションを用意します。
In [ ]:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_ph, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
最小化してほしい cross_entropy を渡して、勾配法で tf.Variable を更新してくれるオペレーションを作成します。
tf.layers を使う場合は tf.Variable の存在が隠蔽されていますが、裏ではニューラルネットの辺にあたる部分 (weight) を tf.Variable として作成して計算グラフに追加しています。
In [ ]:
train_op = tf.train.GradientDescentOptimizer(1e-1).minimize(cross_entropy)
tf.Variable を初期化するオペレーションを作成します。
In [ ]:
init_op = tf.global_variables_initializer()
計算グラフを構築し終えたら、あとはオペレーション (ノード) を選んで実行するだけです。
ランダムサンプリングしたデータを tf.placeholder に渡しつつ、繰り返し train_op を実行します。
In [ ]:
with tf.Session() as sess:
sess.run(init_op)
for i in range(3001):
x_train, y_train = mnist.train.next_batch(100)
sess.run(train_op, feed_dict={x_ph: x_train, y_ph: y_train})
if i % 100 == 0:
train_loss = sess.run(cross_entropy, feed_dict={x_ph: x_train, y_ph: y_train})
test_loss = sess.run(cross_entropy, feed_dict={x_ph: mnist.test.images, y_ph: mnist.test.labels})
tf.logging.info("Iteration: {0} Training Loss: {1} Test Loss: {2}".format(i, train_loss, test_loss))
test_accuracy = sess.run(accuracy, feed_dict={x_ph: mnist.test.images, y_ph: mnist.test.labels})
tf.logging.info("Accuracy: {}".format(test_accuracy))