Licensed under the Apache License, Version 2.0 (the "License");


In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

AutoGraph: Easy control flow for graphs

Note: This is an archived TF1 notebook. These are configured to run in TF2's compatibility mode but will run in TF1 as well. To use TF1 in Colab, use the %tensorflow_version 1.x magic.
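
If you do want to run this notebook against TF1 in Colab, the runtime is selected with a single magic command executed before TensorFlow is imported. A minimal sketch, left commented out since this archived notebook targets TF2's compatibility mode:


In [ ]:
# Colab-only: uncomment and run this before importing TensorFlow to select
# the TF1 runtime. It has no effect outside Colab.
# %tensorflow_version 1.x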

AutoGraph helps you write complicated graph code using normal Python. Behind the scenes, AutoGraph automatically transforms your code into the equivalent TensorFlow graph code. AutoGraph already supports much of the Python language, and that coverage continues to grow. For a list of supported Python language features, see the Autograph capabilities and limitations.

Setup

Import TensorFlow, AutoGraph, and any supporting modules:


In [ ]:
import tensorflow.compat.v1 as tf

In [ ]:
layers = tf.keras.layers

import numpy as np
import matplotlib.pyplot as plt

We'll enable eager execution for demonstration purposes, but AutoGraph works in both eager and graph execution environments:

Note: AutoGraph-converted code is designed to run during graph execution. When eager execution is enabled, use explicit graphs (as this example shows) or tf.contrib.eager.defun.
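
Under TF2's compatibility mode, eager execution is already on by default; in a standalone TF1 runtime you would switch it on explicitly. A minimal sketch, assuming only the tf.compat.v1 import above:


In [ ]:
# Eager execution is enabled by default under TF2's compatibility mode.
# In a standalone TF1 runtime, enable it explicitly; this must run once,
# at program startup, before any ops or graphs are created.
if not tf.executing_eagerly():
  tf.enable_eager_execution()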

Automatically convert Python control flow

AutoGraph will convert much of the Python language into the equivalent TensorFlow graph building code.

Note: In real applications, batching is essential for performance. The best code to convert to AutoGraph is code where the control flow is decided at the batch level. If decisions are made at the individual example level, you must index and batch the examples to maintain performance while applying the control flow logic.
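
For instance, the per-example decision in square_if_positive below can also be written as a single vectorized op over the whole batch. A minimal sketch of that batch-level alternative using tf.where (not part of the original example):


In [ ]:
# Batch-level alternative to per-example control flow: the "if positive"
# decision is made element-wise by one vectorized op, so no branching is
# needed for each individual example.
def square_if_positive_vectorized(x):
  return tf.where(x > 0, x * x, tf.zeros_like(x))

print(square_if_positive_vectorized(tf.constant([9.0, -9.0])))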

AutoGraph converts a function like:


In [ ]:
def square_if_positive(x):
  if x > 0:
    x = x * x
  else:
    x = 0.0
  return x

To a function that uses graph building:


In [ ]:
print(tf.autograph.to_code(square_if_positive))

Code written for eager execution can run in a tf.Graph with the same results, but with the benefits of graph execution:


In [ ]:
print('Eager results: %2.2f, %2.2f' % (square_if_positive(tf.constant(9.0)),
                                       square_if_positive(tf.constant(-9.0))))

Generate a graph-version and call it:


In [ ]:
tf_square_if_positive = tf.autograph.to_graph(square_if_positive)

with tf.Graph().as_default():
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  g_out1 = tf_square_if_positive(tf.constant( 9.0))
  g_out2 = tf_square_if_positive(tf.constant(-9.0))
  with tf.Session() as sess:
    print('Graph results: %2.2f, %2.2f\n' % (sess.run(g_out1), sess.run(g_out2)))

AutoGraph supports common Python statements like while, for, if, break, and return, with support for nesting. Compare this function with the complicated graph version displayed in the following code blocks:


In [ ]:
# Continue in a loop
def sum_even(items):
  s = 0
  for c in items:
    if c % 2 > 0:
      continue
    s += c
  return s

print('Eager result: %d' % sum_even(tf.constant([10,12,15,20])))

tf_sum_even = tf.autograph.to_graph(sum_even)

with tf.Graph().as_default(), tf.Session() as sess:
    print('Graph result: %d\n\n' % sess.run(tf_sum_even(tf.constant([10,12,15,20]))))

In [ ]:
print(tf.autograph.to_code(sum_even))

tf.function

Use the tf.function decorator:


In [ ]:
@tf.function(
    experimental_autograph_options=tf.autograph.experimental.Feature.EQUALITY_OPERATORS)
def fizzbuzz(i, n):
  while i < n:
    msg = ''
    if i % 3 == 0:
      msg += 'Fizz'
    if i % 5 == 0:
      msg += 'Buzz'
    if msg == '':
      msg = tf.as_string(i)
    tf.print(msg)
    i += 1
  return i

with tf.Graph().as_default():
  final_i = fizzbuzz(tf.constant(10), tf.constant(16))
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  with tf.Session() as sess:
    sess.run(final_i)

Examples

Let's demonstrate some useful Python language features.

Assert

AutoGraph can automatically convert the Python assert statement into the equivalent tf.Assert code:


In [ ]:
@tf.function(
    experimental_autograph_options=(
        tf.autograph.experimental.Feature.ASSERT_STATEMENTS,
        tf.autograph.experimental.Feature.EQUALITY_OPERATORS))
def inverse(x):
  assert x != 0.0, 'Do not pass zero!'
  return 1.0 / x

with tf.Graph().as_default(), tf.Session() as sess:
  try:
    print(sess.run(inverse(tf.constant(0.0))))
  except tf.errors.InvalidArgumentError as e:
    print('Got error message:\n    %s' % e.message)

Print

Optionally, you may use the Python print function in-graph, when combined with the automatic control dependency management of tf.function:


In [ ]:
@tf.function(
    experimental_autograph_options=tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS)
def count(n):
  i = 0
  while i < n:
    print(i)
    i += 1
  return n

with tf.Graph().as_default(), tf.Session() as sess:
    sess.run(count(tf.constant(5)))

Lists

Append to lists in loops (tensor list ops are automatically created):


In [ ]:
@tf.function(
    experimental_autograph_options=tf.autograph.experimental.Feature.LISTS)
def arange(n):
  z = tf.TensorArray(tf.int32, size=0, dynamic_size=True)

  for i in tf.range(n):
    z.append(i)

  return z.stack()


with tf.Graph().as_default(), tf.Session() as sess:
    print(sess.run(arange(tf.constant(10))))

Nested control flow


In [ ]:
@tf.function(
    experimental_autograph_options=tf.autograph.experimental.Feature.EQUALITY_OPERATORS)
def nearest_odd_square(x):
  if x > 0:
    x = x * x
    if x % 2 == 0:
      x = x + 1
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(nearest_odd_square(tf.constant(4))))
    print(sess.run(nearest_odd_square(tf.constant(5))))
    print(sess.run(nearest_odd_square(tf.constant(6))))

While loop


In [ ]:
@tf.function
def square_until_stop(x, y):
  while x < y:
    x = x * x
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))

For loop


In [ ]:
@tf.function(
    experimental_autograph_options=tf.autograph.experimental.Feature.LISTS)
def squares(nums):

  result = tf.TensorArray(tf.int64, size=0, dynamic_size=True)
  
  for num in nums:
    result.append(num * num)

  return result.stack()

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(squares(tf.constant(np.arange(10)))))

Break


In [ ]:
@tf.function
def argwhere_cumsum(x, threshold):
  current_sum = 0.0
  idx = 0
  for i in tf.range(len(x)):
    idx = i
    if current_sum >= threshold:
      break
    current_sum += x[i]
  return idx

N = 10
with tf.Graph().as_default():
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))
    print(sess.run(idx))

Interoperation with tf.Keras

It's easy to integrate tf.autograph with tf.keras.

Stateless functions

For stateless functions, like collatz shown below, the easiest way to include them in a keras model is to wrap them up as a layer using tf.keras.layers.Lambda.


In [ ]:
import numpy as np

@tf.function(
    experimental_autograph_options=(
        tf.autograph.experimental.Feature.ASSERT_STATEMENTS,
        tf.autograph.experimental.Feature.EQUALITY_OPERATORS,
        ))
def collatz(x):
  x = tf.reshape(x,())
  assert x > 0
  n = tf.convert_to_tensor((0,))
  while x != 1:
    n += 1
    if x % 2 == 0:
      x = x // 2
    else:
      x = 3 * x + 1

  return n

with tf.Graph().as_default():
  model = tf.keras.Sequential([
    tf.keras.layers.Lambda(collatz, input_shape=(1,), output_shape=())
  ])

  result = model.predict(np.array([6171]))
  print(result)

Advanced custom models

For subclasses of Keras models, the easiest way is to convert their call method. See the TensorFlow Keras guide for details on how to build on these classes.

Here is a simple example of the stochastic network depth technique:


In [ ]:
# `K` is used to check if we're in train or test mode.
K = tf.keras.backend

class StochasticNetworkDepth(tf.keras.Sequential):
  def __init__(self, layers, pfirst=1.0, plast=0.5,**kwargs):
    self.pfirst = pfirst
    self.plast = plast
    super(StochasticNetworkDepth, self).__init__(layers,**kwargs)

  def build(self, input_shape):
    self.depth = len(self.layers)
    self.plims = np.linspace(self.pfirst, self.plast, self.depth + 1)[:-1]
    super(StochasticNetworkDepth, self).build(input_shape.as_list())

  def call(self, inputs):
    training = tf.cast(K.learning_phase(), dtype=bool)
    if not training:
      count = self.depth
      return super(StochasticNetworkDepth, self).call(inputs), count

    p = tf.random_uniform((self.depth,))

    keeps = (p <= self.plims)
    x = inputs

    count = tf.reduce_sum(tf.cast(keeps, tf.int32))
    for i in range(self.depth):
      if keeps[i]:
        x = self.layers[i](x)

    # return both the final-layer output and the number of layers executed.
    return x, count

StochasticNetworkDepth.call = tf.autograph.to_graph(StochasticNetworkDepth.call)

Let's try it on MNIST-shaped data:


In [ ]:
train_batch = np.random.randn(64, 28, 28, 1).astype(np.float32)

Build a simple stack of conv layers for the stochastic depth model:


In [ ]:
with tf.Graph().as_default() as g:
  model = StochasticNetworkDepth(
      [
        layers.Conv2D(filters=16, activation=tf.nn.relu,
                  kernel_size=(3, 3), padding='same')
        for n in range(20)
      ],
      pfirst=1.0, plast=0.5
  )

  model.build(tf.TensorShape((None, None, None, 1)))

  init = tf.global_variables_initializer()

Now test it to ensure it behaves as expected in train and test modes:


In [ ]:
# Use an explicit session here so we can set the train/test switch, and
# inspect the layer count returned by `call`
with tf.Session(graph=g) as sess:
  init.run()

  for phase, name in enumerate(['test','train']):
    K.set_learning_phase(phase)
    result, count = model(tf.convert_to_tensor(train_batch, dtype=tf.float32))

    result1, count1 = sess.run((result, count))
    result2, count2 = sess.run((result, count))

    delta = (result1 - result2)
    print(name, "sum abs delta: ", abs(delta).mean())
    print("    layers 1st call: ", count1)
    print("    layers 2nd call: ", count2)
    print()

Advanced example: An in-graph training loop

The previous section showed that AutoGraph can be used inside Keras layers and models. Keras models can also be used in AutoGraph code.

Since writing control flow in AutoGraph is easy, running a training loop in a TensorFlow graph should also be easy.

This example shows how to train a simple Keras model on MNIST, where the entire training process (loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence) is performed in-graph.

Download data


In [ ]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

Define the model


In [ ]:
def mlp_model(input_shape):
  model = tf.keras.Sequential((
      tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),
      tf.keras.layers.Dense(100, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')))
  model.build()
  return model


def predict(m, x, y):
  x = tf.to_float(x) / 255.0
  y = tf.one_hot(tf.squeeze(y), 10)
  y_p = m(tf.reshape(x, (-1, 28 * 28)))
  losses = tf.keras.losses.categorical_crossentropy(y, y_p)
  l = tf.reduce_mean(losses)
  accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)
  accuracy = tf.reduce_mean(accuracies)
  return l, accuracy


def fit(m, x, y, opt):
  l, accuracy = predict(m, x, y)
  # Autograph automatically adds the necessary `tf.control_dependencies` here.
  # (Without them nothing depends on `opt.minimize`, so it doesn't run.)
  # This makes it much more like eager-code.
  opt.minimize(l)
  return l, accuracy


def setup_mnist_data(is_training, batch_size):
  if is_training:
    ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    ds = ds.shuffle(batch_size * 10)
  else:
    ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

  ds = ds.repeat()
  ds = ds.batch(batch_size)
  return ds

Define the training loop


In [ ]:
def train(train_ds, test_ds, learning_rate, max_steps):
  m = mlp_model((28 * 28,))
  opt = tf.train.AdamOptimizer(learning_rate)

  train_losses = tf.TensorArray(tf.float32, size=0, dynamic_size=True, element_shape=())
  test_losses = tf.TensorArray(tf.float32, size=0, dynamic_size=True, element_shape=())
  train_accuracies = tf.TensorArray(tf.float32, size=0, dynamic_size=True, element_shape=())
  test_accuracies = tf.TensorArray(tf.float32, size=0, dynamic_size=True, element_shape=())

  # This entire training loop will be run in-graph.
  i = tf.constant(0)
  for (train_x, train_y), (test_x, test_y) in tf.data.Dataset.zip((train_ds, test_ds)):
    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)
    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)
    if i % 50 == 0:
      print('Step', i, 'train loss:', step_train_loss, 'test loss:',
            step_test_loss, 'train accuracy:', step_train_accuracy,
            'test accuracy:', step_test_accuracy)
    train_losses.append(step_train_loss)
    test_losses.append(step_test_loss)
    train_accuracies.append(step_train_accuracy)
    test_accuracies.append(step_test_accuracy)

    i += 1
    if i >= max_steps:
      break

  # We've recorded our loss values and accuracies
  # to a list in a graph with AutoGraph's help.
  # In order to return the values as a Tensor,
  # we need to stack them before returning them.
  return (train_losses.stack(), test_losses.stack(),
          train_accuracies.stack(), test_accuracies.stack())
  
train = tf.autograph.to_graph(
    train,
    experimental_optional_features=(
        tf.autograph.experimental.Feature.LISTS,
        tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS,
        tf.autograph.experimental.Feature.EQUALITY_OPERATORS,
        tf.autograph.experimental.Feature.AUTO_CONTROL_DEPS))

Now build the graph and run the training loop:


In [ ]:
with tf.Graph().as_default() as g:
  learning_rate = 0.005
  max_steps=500

  train_ds = setup_mnist_data(True, 50)
  test_ds = setup_mnist_data(False, 1000)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = train(train_ds, test_ds, learning_rate, max_steps)

  init = tf.global_variables_initializer()

with tf.Session(graph=g) as sess:
  sess.run(init)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,
                                test_accuracies])

In [ ]:
plt.title('MNIST train/test losses')
plt.plot(train_losses, label='train loss')
plt.plot(test_losses, label='test loss')
plt.legend()
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.show()
plt.title('MNIST train/test accuracies')
plt.plot(train_accuracies, label='train accuracy')
plt.plot(test_accuracies, label='test accuracy')
plt.legend(loc='lower right')
plt.xlabel('Training step')
plt.ylabel('Accuracy')
plt.show()