In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TensorFlowとXLAライブラリをインポートします。XLAには、一部または全てのモデルを XLA でコンパイルする実験的なAPIである xla.compile() が含まれています。
In [ ]:
import tensorflow as tf
from tensorflow.contrib.compiler import xla
必要ないくつかの定数を定義し、 MNISTのデータセットを用意します。
In [ ]:
# それぞれの入力イメージの大きさは、 28 x 28ピクセル
IMAGE_SIZE = 28 * 28
# 個別の数字のラベル [0..9] の個数
NUM_CLASSES = 10
# それぞれのトレーニングバッチ(ステップ)での標本数
TRAIN_BATCH_SIZE = 100
# トレーニングステップを実行する回数
TRAIN_STEPS = 1000
In [4]:
# MNISTデータセットをロードする。
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)
iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)
images, labels = iterator.get_next()
images = tf.reshape(images, [-1, IMAGE_SIZE])
images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)
In [ ]:
def build_mnist_model(x, y_):
y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
return y, train_step
In [6]:
[y] = xla.compile(build_mnist_model, inputs=[images, labels])
グラフをコンパイルするとき、XLAはターゲット関数によって構築されたグラフの全てのノードを、いくつかのXLAのオペレータで置き換えます。
xla.compileは、生成されたXLAのオペレータから独立して実行できる tf.Operation を返しません
代わりに、ターゲット関数から返された tf.Operation ノードは、返された全ての tf.Tensor の値との制御依存関係として追加されます。これにより、 返されたテンソルが評価されるときに、 tf.Operation ノードの実行をトリガします。
擬似コードによるxla.compileの実装は、以下のようになります:
# TensorFlowに、XLAが扱いやすい方法でコードを実行するよう依頼する
y, train_step = build_mnist_model(images, labels)
with tf.control_dependencies([train_step]):
y = tf.identity(y)
# TensorFlowに、XLAが扱いやすい方法でコードの実行を停止するよう依頼する
xla.compile()は常に tf.Tensor のリスト(1要素しか無かったとしても)を返します。
もしあなたが構築したグラフを今表示したら、通常のTensorFlowのグラフとそれほど変わらないことがわかり、前に述べたXLAのオペレータを見つけることができないでしょう。これは、あなたが sess.run() でグラフを実行しようとしても、実際のコンパイルは後ほど発生するからです。後ほど、TensorFlowは実際にXLAオペレータを生成する一連のグラフ書き換えパスをトリガーします。これは、すべての入力がそろったときに、計算をコンパイルして実行します。
In [ ]:
# セッションを作成しすべての変数を初期化。
# xla.compile()は、Keras model.fit() APIやTF eager modeとはまだ動作しません。
sess = tf.Session()
sess.run(tf.global_variables_initializer())
以下のコードブロックはモデルを学習します。 y の評価は、制御依存関係がある train_step をトリガします。これは、モデル変数を更新します。
In [8]:
# 学習用データセットを与える
sess.run(iterator.make_initializer(train_ds))
# TRAIN_STEPS ステップだけ実行する
for i in range(TRAIN_STEPS):
sess.run(y)
print("Model trained for %s steps." % TRAIN_STEPS)
In [9]:
# 学習済みモデルをテストする
# テスト用データセットを与える
sess.run(iterator.make_initializer(test_ds))
# 精度を計算する
correct_prediction = tf.equal(tf.argmax(y, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % sess.run(accuracy))
In [ ]:
# セッションを片付ける
sess.close()