In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

TensorFlow 2.0 では Eager Execution が既定で有効になっています。ユーザーインターフェイスは直感的で柔軟です(演算を一度だけ行う場合にはずっと簡単に、かつ迅速に実行されます)。しかしながら、それは性能と展開の面での犠牲の上に成り立っています。

最高性能を得ながら、モデルをどこへでも展開できるようにするには、tf.function を使ってプログラムから計算グラフを作成します。 AutoGraph のおかげで、驚くほど多くの Python コードが tf.function でそのまま動作しますが、気をつけなければならない落とし穴も存在します。

ポイントと推奨事項は下記の通りです。

  • オブジェクトの変更やリストへの追加のような Python の副作用に依存しないこと
  • tf.functions は NumPy の演算や Python の組み込み演算よりも、TensorFlow の演算に適していること
  • 迷ったときは、for x in y というイディオムを使うこと

In [ ]:
import tensorflow as tf

In [ ]:
import contextlib

# 遭遇するかもしれないいくつかのエラーをデモするためのヘルパー関数
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}: {}'.format(error_class, e))
  except Exception as e:
    print('Got unexpected exception \n  {}: {}'.format(type(e), e))
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

あなたが定義した tf.function は TensorFlow Core の演算に似たものです。例えばそれを即時に実行することも、計算グラフで使うこともできますし、勾配を計算することも可能です。


In [ ]:
# function は演算のように振る舞う

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

In [ ]:
# function は勾配を計算できる

@tf.function
def add(a, b):
  return a + b

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)

In [ ]:
# function 内で function を使うこともできる

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

トレーシングとポリモーフィズム

Python の動的型付けは、関数をさまざまな型の引数で呼び出すことができ、Python がそれぞれのシナリオで異なる動作をするということを意味します。

他方で、TensorFlow の計算グラフでは、dtype と shape の次元が静的であることが必要です。tf.function は、正しい計算グラフを生成するために必要なときには関数を再トレースして、このギャップをつなぐ役割を果たします。

異なる型の引数を使って関数を呼び出し、何が起きるか見てみましょう。


In [ ]:
# Function はポリモーフィック

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

トレースの動作を制御するためには、下記のようなテクニックを使います。

  • 新しい tf.function を作成する。別々の tf.function オブジェクトがトレースを共有することはない。
  • 特定のトレースを得るには get_concrete_function メソッドを使用する。
  • 計算グラフの呼び出し時に1回だけトレースを行うには、 input_signature を指定して tf.function を呼び出す。

In [ ]:
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))

In [ ]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# 1次元のテンソルを input signature として指定しているので、これは失敗する
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

いつ再トレースするのか?

ポリモーフィックな tf.function はトレーシングによって生成された具象関数のキャッシュを保持しています。キャッシュのキーは、実際にはその関数の引数及びキーワード引数から生成されたキーのタプルです。tf.Tensor 引数から生成されるキーは、テンソルの shape と型です。Python の組み込み型引数から生成されるキーはその値です。それ以外の Python の型では、キーはオブジェクトの id() に基づいており、メソッドはクラスのインスタンスひとつずつ独立にトレースされます。将来、TensorFlowには、Python オブジェクトについて安全にテンソルに変換できるような、より洗練されたキャッシングが追加されるかもしれません。

引数は Python か? Tensor か?

しばしば、ハイパーパラメータやグラフ構成を制御するために Python の組み込み型の引数が使われます。例えば、num_layers=10training=True あるいは nonlinearity='relu' のようにです。このため、この Python の組み込み型の引数が変更されると、計算グラフを再びトレースする必要があるということになります。

しかし、グラフの生成を制御するために Python の組み込み型の引数を使用する必要はありません。これらのケースでは、Python引数の値の変更が不必要な再トレースを引き起こす可能性があります。例えば、この訓練ループでは、AutoGraph は動的に展開を行います。複数回トレースを行っていますが、生成される計算グラフは全く変わりません。これは少し非効率です。


In [ ]:
def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)

ここでの簡単な回避方法は、生成されたグラフの shape が変わらないのであれば、引数をテンソルにキャストすることです。


In [ ]:
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

tf.function の中の副作用

一般的には、(印字やオブジェクト変更のような)Python の副作用は、トレーシングの最中にだけ発生します。それでは、どうしたら tf.function で安定的に副作用を起こすことができるでしょうか?

一般的な原則は、トレースをデバッグする際にだけ Python の副作用を使用するというものです。あるいは、tf.Variable.assigntf.print、そして tf.summary のような TensorFlow の演算を使うことで、コードがトレースされるときにも、TensorFlowランタイムによって都度呼び出される際にも、確実に実行されるようにできます。一般には、関数型のスタイルを使用することで最も良い結果を得られます。


In [ ]:
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

tf.function が呼び出されるたびに Python のコードを実行したいのであれば、tf.py_function がぴったりです。tf.py_function の欠点は、ポータブルでないこと、それほど性能が高くないこと、(マルチGPU、TPUの)分散環境ではうまく動作しないことなどです。また、tf.py_function は計算グラフに組み込まれるため、入出力すべてをテンソルにキャストします。


In [ ]:
external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1

Python の状態に注意

ジェネレーターやイテレーターなど Python の機能の多くは、状態を追跡するために Python のランタイムに依存しています。これらの仕組みは、一般的には Eager モードでも期待通りに動作しますが、トレーシングの振る舞いにより、tf.function の中では予期しないことが起きることがあります。

1例として、イテレーターの状態が進むのは Python の副作用であり、トレーシングの中だけで発生します。


In [ ]:
external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# 次のコードは、イテレーターの次の値を使うのではなく、最初の値を再利用する
buggy_consume_next(iterator)
buggy_consume_next(iterator)

イテレーターが tf.function の中で生成されすべて使われる場合には、正しく動作するはずです。しかし、イテレーター全体がトレースされることとなり、巨大な計算グラフの生成をまねく可能性があります。これは、望みどおりの動作かもしれません。しかし、もし Python のリストとして表されたメモリー上の巨大なデータセットを使って訓練を行うとすると、これは非常に大きな計算グラフを生成することになり、tf.function がスピードアップにはつながらないと考えられます。

Python データを繰り返し使用する場合、もっとも安全な方法は tf.data.Dataset でラップして、for x in y というイディオムを使用することです。AutoGraph には、y がテンソルあるいは tf.data.Dataset である場合、for ループを安全に変換する特別な機能があります。


In [ ]:
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # ダミー計算
  return loss

small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

Python/Numpy のデータを Dataset でラップする際には、tf.data.Dataset.from_generatortf.data.Dataset.from_tensors の違いに留意しましょう。前者はデータを Python のまま保持し tf.py_function を通じて取得するため、性能に影響する場合があります。これに対して後者はデータのコピーを計算グラフの中の、ひとつの大きな tf.constant() に結びつけるため、メモリー消費に影響する可能性があります。

TFRecordDataset/CsvDataset/などを通じてデータをファイルから読み込むことが、データを使用する最も効率的な方法です。TensorFlow 自身が Python とは関係なく非同期のデータ読み込みとプリフェッチを管理することができるからです。

自動的な依存関係の制御

プログラミングモデルとしての関数が一般的なデータフローグラフに対して非常に優位である点は、意図したコードの振る舞いがどのようなものであるかということについて、より多くの情報をランタイムに与えられるということにあります。

例えば、同じ変数を何度も読んだり書いたりするコードを書く場合、データフローグラフではもともと意図されていた演算の順番を自然に組み込むわけではありません。tf.function の中では、もともとの Python コードの文の実行順序を参照することで、実行順序の曖昧さを解消します。これにより、tf.function の中のステートフルな演算の順序が、先行実行モードのセマンティクスを模していることになります。

これは、手動で制御の依存関係を加える必要がないことを意味しています。tf.function は十分賢いので、あなたのコードが正しく動作するために必要十分な最小限の制御の依存関係を追加してくれます。


In [ ]:
# 自動的な依存関係の制御

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0

変数

tf.function の中では、意図したコードの実行順序を活用するという同じアイデアを使って、変数の作成と活用を簡単に行うことができます。しかし、ひとつだけ非常に重要な欠点があります。それは、変数を使った場合、先行実行モードとグラフモードでは動作が変わるコードを書いてしまう可能性があるということです。

特に、呼び出しの都度新しい変数を作成する場合にこれが発生します。トレーシングの意味では、tf.function は呼び出しのたびに同じ変数を再利用しますが、Eager モードでは呼び出しごとに新しい変数を生成します。この間違いを防止するため、tf.function は危険な変数の生成動作を見つけるとエラーを発生させます。


In [ ]:
@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)

In [ ]:
# しかし、曖昧さの無いコードは大丈夫

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0

In [ ]:
# 初めて関数が実行されるときだけ変数が生成されることを保証できれば
# tf.function 内で変数を作成できる

class C: pass
obj = C(); obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0

In [ ]:
# 変数の初期化は、関数の引数や他の変数の値に依存可能
# 制御の依存関係を生成するのと同じ手法で、正しい初期化の順序を発見可能

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))

AutoGraph の使用

autograph ライブラリは tf.function に完全に統合されており、計算グラフの中で動的に実行される条件文や繰り返しを書くことができます。

tf.condtf.while_looptf.function でも使えますが、制御フローを含むコードは、命令形式で書いたほうが書きやすいし理解しやすいです。


In [ ]:
# 単純な繰り返し

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))

In [ ]:
# 興味があれば AutoGraph が生成するコードを調べることができる
# ただし、アセンブリ言語を読むような感じがする

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))

AutoGraph: 条件分岐

AutoGraph は if 文を等価である tf.cond の呼び出しに変換します。

この置換は条件がテンソルである場合に行われます。そうでない場合には、条件分岐はトレーシングの中で実行されます。


In [ ]:
def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'cond' for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) executes normally.".format(
        f.__name__, ', '.join(map(str, args))))

In [ ]:
@tf.function
def hyperparam_cond(x, training=True):
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

@tf.function
def maybe_tensor_cond(x):
  if x < 0:
    x = -x
  return x

test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))
test_tf_cond(maybe_tensor_cond, tf.constant(-1))
test_tf_cond(maybe_tensor_cond, -1)

tf.cond には、色々と注意すべき細かな点があります。

  • tf.cond は条件分岐の両方をトレーシングし、条件に従って実行時に適切な分岐を選択することで機能します。分岐の両方をトレースすることで、Python プログラムを予期せず実行する可能性があります。
  • tf.cond では、分岐の一方が後ほど使用されるテンソルを作成する場合、もう一方の分岐もそのテンソルを作成することが必要です。

In [ ]:
@tf.function
def f():
  x = tf.constant(0)
  if tf.constant(True):
    x = x + 1
    print("Tracing `then` branch")
  else:
    x = x - 1
    print("Tracing `else` branch")
  return x

f()

In [ ]:
@tf.function
def f():
  if tf.constant(True):
    x = tf.ones([3, 3])
  return x

# 分岐のどちらの枝でも `x` を定義する必要があるためエラーが発生
with assert_raises(ValueError):
  f()

AutoGraph と繰り返し

AutoGraph には繰り返しの変換にいくつかの単純なルールがあります。

  • for: イテラブルがテンソルである場合に変換する
  • while: while 条件がテンソルに依存している場合に変換する

繰り返しが変換される場合、tf.while_loop によって動的に展開されます。あるいは、 for x in tf.data.Dataset という特別なケースの場合には、 tf.data.Dataset.reduce に変換されます。

繰り返しが変換されない場合、それは静的に展開されます。


In [ ]:
def test_dynamically_unrolled(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'while' for node in g.as_graph_def().node):
    print("{}({}) uses tf.while_loop.".format(
        f.__name__, ', '.join(map(str, args))))
  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
    print("{}({}) uses tf.data.Dataset.reduce.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) gets unrolled.".format(
        f.__name__, ', '.join(map(str, args))))

In [ ]:
@tf.function
def for_in_range():
  x = 0
  for i in range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_range)

In [ ]:
@tf.function
def for_in_tfrange():
  x = tf.constant(0, dtype=tf.int32)
  for i in tf.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfrange)

In [ ]:
@tf.function
def for_in_tfdataset():
  x = tf.constant(0, dtype=tf.int64)
  for i in tf.data.Dataset.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfdataset)

In [ ]:
@tf.function
def while_py_cond():
  x = 5
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_py_cond)

In [ ]:
@tf.function
def while_tf_cond():
  x = tf.constant(5)
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_tf_cond)

繰り返しに、テンソルに依存する break や、途中での return がある場合、一番外側の条件あるいはイテラブルはテンソルである必要があります。

比較してみましょう。


In [ ]:
@tf.function
def while_py_true_py_break(x):
  while True:  # py true
    if x == 0: # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_py_true_py_break, 5)

In [ ]:
@tf.function
def buggy_while_py_true_tf_break(x):
  while True:   # py true
    if tf.equal(x, 0): # tf break
      break
    x -= 1
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)

In [ ]:
@tf.function
def while_tf_true_tf_break(x):
  while tf.constant(True): # tf true
    if x == 0:  # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_tf_true_tf_break, 5)

In [ ]:
@tf.function
def buggy_py_for_tf_break():
  x = 0
  for i in range(5):  # py for
    if tf.equal(i, 3): # tf break
      break
    x += i
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_py_for_tf_break)

In [ ]:
@tf.function
def tf_for_py_break():
  x = 0
  for i in tf.range(5): # tf for
    if i == 3:  # py break
      break
    x += i
  return x

test_dynamically_unrolled(tf_for_py_break)

動的に展開される繰り返しの結果を集計するため、tf.TensorArray を使いたくなるかもしれません。


In [ ]:
batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))

tf.cond と同様に、tf.while_loop にも、色々と注意すべき細かな点があります。

  • 繰り返しの実行回数が 0 である可能性があるため、while_loop の後で使用されるテンソルは、繰り返しの前に初期化されなければならない
  • すべての繰り返しの変数は、各繰り返しを通じてその形状と dtype が変わらないことが必要

In [ ]:
@tf.function
def buggy_loop_var_uninitialized():
  for i in tf.range(3):
    x = i
  return x

with assert_raises(ValueError):
  buggy_loop_var_uninitialized()

In [ ]:
@tf.function
def f():
  x = tf.constant(0)
  for i in tf.range(3):
    x = i
  return x

f()

In [ ]:
@tf.function
def buggy_loop_type_changes():
  x = tf.constant(0, dtype=tf.float32)
  for i in tf.range(3): # tf.int32 型のテンソルを1つづつ取り出して…
    x = i
  return x

with assert_raises(tf.errors.InvalidArgumentError):
  buggy_loop_type_changes()

In [ ]:
@tf.function
def buggy_concat():
  x = tf.ones([0, 10])
  for i in tf.range(5):
    x = tf.concat([x, tf.ones([1, 10])], axis=0)
  return x

with assert_raises(ValueError):
  buggy_concat()

In [ ]:
@tf.function
def concat_with_padding():
  x = tf.zeros([5, 10])
  for i in tf.range(5):
    x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)
    x.set_shape([5, 10])
  return x

concat_with_padding()