Getting Started with TensorFlow

何はともあれ TensorFlow を始めてみましょう!


In [ ]:
import tensorflow as tf
import numpy as np

print(tf.__version__)

Hello TensorFlow

Python を使って足し算をしてみましょう(決して馬鹿にしているわけではなく大真面目です)!


In [ ]:
a = 1.
b = 2.
c = a + b

print(c)

当然ですが 3.0 と答えが表示されます。

今度は TensorFlow で同じような足し算をやってみましょう。


In [ ]:
a = tf.constant(1.)
b = tf.constant(2.)
c = tf.add(a, b)

print(c)

Tensor というオブジェクトが表示されますね。 実はまだこの時点ではデータフローグラフが作成されただけで、足し算は行われていません。 足し算を実行するにはセッションを介して実行する必要があります。


In [ ]:
with tf.Session() as sess:
    result = sess.run(c)

print(result)

行列演算

NumPy と TensorFlow で行列やベクトルの計算方法を比較してみましょう。

NumPy


In [ ]:
a = np.array([5, 3, 8])
b = np.array([3, -1, 2])
c = np.add(a, b)

print(c)

TensorFlow


In [ ]:
a = tf.constant([5, 3, 8])
b = tf.constant([3, -1, 2])
c = tf.add(a, b)

print(c)

In [ ]:
with tf.Session() as sess:
    result = sess.run(c)

print(result)

TensorFlow + placeholder

計算をするたびにデータフローグラフを作っていると、データフローグラフがどんどん肥大化してしまいます。 そのため、 TensorFlow には一部の値を差し替えつつデータフローグラフの共通部分を再利用するための placeholder という仕組みが存在しています。


In [ ]:
a = tf.placeholder(dtype=tf.int32, shape=(None,))
b = tf.placeholder(dtype=tf.int32, shape=(None,))
c = tf.add(a, b)

with tf.Session() as sess:
    result1 = sess.run(c, feed_dict={a: [3, 4, 5], b: [-1, 2, 3]})
    result2 = sess.run(c, feed_dict={a: [1, 2, 3], b: [3, 2, 1]})

print(result1)
print(result2)

典型的な使い方として、学習時に使うデータを placeholder で定義しておくという方法があります。 placeholder に対して少しずつ学習用データを流していくというのが TensorFlow で機械学習のアルゴリズムを実装するときの定石です。

tf.Variable

TensorFlow では、計算に使われた値は基本的に捨てられていきます。 tf.add の結果はセッションを介して実行するたびに計算し直すことになりますし、 tf.constant の値はメモリ上に保持されず必要に応じて再生成されます。 後から tf.constant の値を書き換えることもできません。

ただし tf.Variable だけが例外で、値をメモリ上に保持し続けて後から書き換えることも可能になっています。 ニューラルネットやその他多くの機械学習手法で weight として使うことが想定されています。


In [ ]:
v = tf.Variable([1, 2])
assign_op = tf.assign(v, [2, 3])
init_op = tf.global_variables_initializer()

tf.Variable 最初に初期化処理を行う必要があるので init_op を実行します。 その後 v の値を表示すると assign_op を実行する前後で値が書き換えられていることが確認できます。


In [ ]:
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(v))
    sess.run(assign_v)
    print(sess.run(v))

tf.Variable に何の値が入っているかという情報はセッションと紐付けられています。 計算グラフは tf.Variable を計算にどう使うかという手順の情報だけを持っており、具体的な値の情報はセッションが持っているということを覚えておくと、コードを書く時に混乱せずに済むでしょう。