In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use `SavedModel`.

This guide will help you conceptualize how `tf.function` works under the hood so you can use it effectively.

The main takeaways and recommendations are:

- Debug in eager mode, then decorate with `@tf.function`.
- Don't rely on Python side effects like object mutation or list appends.
- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.
In [ ]:
# Install tf-nightly for `pretty_printed_concrete_signatures`
# it will be in TF 2.3
!pip install tf-nightly
In [ ]:
import tensorflow as tf
Define a helper function to demonstrate the kinds of errors you might encounter:
In [ ]:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))
In [ ]:
@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  # [[2., 2.], [2., 2.]]
In [ ]:
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
You can use `Function`s inside other `Function`s.
In [ ]:
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
`Function`s can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.
In [ ]:
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
  return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Python's dynamic typing means that you can call functions with a variety of argument types, and Python can do something different in each scenario.

Yet, to create a TensorFlow graph, static `dtypes` and shape dimensions are required. `tf.function` bridges this gap by wrapping a Python function to create a `Function` object. Based on the given inputs, the `Function` selects an appropriate graph, retracing the Python function as necessary. Once you understand why and when tracing happens, it's much easier to use `tf.function` effectively!

You can call a `Function` with arguments of different types to see this polymorphic behavior in action.
In [ ]:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Note that if you repeatedly call a `Function` with the same argument type, TensorFlow will reuse a previously traced graph, as the generated graph would be identical.
In [ ]:
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
(The following change is available in TensorFlow nightly, and will be available in TensorFlow 2.3.)
You can use `pretty_printed_concrete_signatures()` to see all of the available traces:
In [ ]:
print(double.pretty_printed_concrete_signatures())
So far, you've seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:

- A `tf.Graph` is the raw, language-agnostic, portable representation of your computation.
- A `ConcreteFunction` is an eagerly-executing wrapper around a `tf.Graph`.
- A `Function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.
- `tf.function` wraps a Python function, returning a `Function` object.
In [ ]:
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
In [ ]:
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
(The following change is available in TensorFlow nightly, and will be available in TensorFlow 2.3.)
Printing a `ConcreteFunction` displays a summary of its input arguments (with types) and its output type.
In [ ]:
print(double_strings)
You can also directly retrieve a concrete function's signature.
In [ ]:
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
Using a concrete trace with incompatible types will throw an error:
In [ ]:
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.
In [ ]:
@tf.function
def pow(a, b):
  return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
In [ ]:
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
In [ ]:
graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
In general, debugging code is easier in eager mode than inside `tf.function`. You should ensure that your code executes error-free in eager mode before decorating with `tf.function`. To assist in the debugging process, you can call `tf.config.run_functions_eagerly(True)` to globally disable and reenable `tf.function`.
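For example, here is a minimal sketch of toggling eager execution while debugging (the function `f` here is hypothetical):

In [ ]:
@tf.function
def f(x):
  # In eager mode, ordinary `print` calls and debugger breakpoints work as usual.
  return x * x

tf.config.run_functions_eagerly(True)   # execute the Python body directly
print(f(tf.constant(2.0)))
tf.config.run_functions_eagerly(False)  # restore graph execution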
When tracking down issues that only appear within `tf.function`, here are some tips:

- Plain old Python `print` calls only execute during tracing, helping you track down when your function gets (re)traced.
- `tf.print` calls will execute every time, and can help you track down intermediate values during execution.
- `tf.debugging.enable_check_numerics` is an easy way to track down where NaNs and Inf are created.
- `pdb` can help you understand what's going on during tracing. (Caveat: PDB will drop you into AutoGraph-transformed source code.)

A `Function` determines whether to reuse a traced concrete function by computing a cache key from an input's args and kwargs:
- The key generated for a `tf.Tensor` argument is its shape and dtype.
- The key generated for a `tf.Variable` argument is its `id()`.
- The key generated for a Python primitive is its value; the key generated for nested `dict`s, `list`s, `tuple`s, `namedtuple`s, and `attr`s is the flattened tuple. (As a result of this flattening, calling a concrete function with a different nesting structure than the one used during tracing will result in a TypeError.)
- For all other Python types, the keys are based on the object's `id()` so that methods are traced independently for each instance of a class.

Retracing helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your `Function` retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use `tf.function`.
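To see how these cache keys drive retracing, here is a minimal sketch (the function `show_key` is hypothetical):

In [ ]:
@tf.function
def show_key(x):
  print("Tracing with", x)  # runs only when a new trace is created
  return x

show_key(tf.constant([1, 2]))     # traces: first call
show_key(tf.constant([3, 4]))     # no retrace: same shape and dtype
show_key(tf.constant([1, 2, 3]))  # retraces: different shape
show_key(1)                       # retraces: a Python primitive is keyed by value
show_key(2)                       # retraces again: different Python value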
To control the tracing behavior, you can use the following techniques:

Specify `input_signature` in `tf.function` to limit tracing.
In [ ]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([1.0, 2.0]))
Specify a `[None]` dimension in `tf.TensorSpec` to allow for flexibility in trace reuse.

Since TensorFlow matches tensors based on their shape, using a `None` dimension as a wildcard will allow `Function`s to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch (see the Transformer and Deep Dream tutorials for examples).
In [ ]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Cast Python arguments to Tensors to reduce retracing.

Often, Python arguments are used to control hyperparameters and graph construction - for example, `num_layers=10` or `training=True` or `nonlinearity='relu'`. So if the Python argument changes, it makes sense that you'd have to retrace the graph.

However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.
In [ ]:
def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
If you need to force retracing, create a new `Function`. Separate `Function` objects are guaranteed not to share traces.
In [ ]:
def f():
  print('Tracing!')
  tf.print('Executing')
tf.function(f)()
tf.function(f)()
Python side effects like printing, appending to lists, and mutating globals only happen the first time you call a `Function` with a set of inputs. Afterwards, the traced `tf.Graph` is reexecuted, without executing the Python code.

The general rule of thumb is to only use Python side effects to debug your traces. Otherwise, TensorFlow ops like `tf.Variable.assign`, `tf.print`, and `tf.summary` are the best way to ensure your code will be traced and executed by the TensorFlow runtime with each call.
In [ ]:
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
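If you need state that updates on every call, a minimal sketch (the names `counter` and `count_calls` are hypothetical) replaces a mutated Python global with `tf.Variable.assign_add`, which is a TensorFlow op and therefore executes at every call:

In [ ]:
counter = tf.Variable(0)

@tf.function
def count_calls():
  # assign_add is traced into the graph, so it runs on every call,
  # not just during tracing.
  return counter.assign_add(1)

print(count_calls())  # 1
print(count_calls())  # 2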
Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, many unexpected things can happen inside a `Function`.

To give one example, advancing iterator state is a Python side effect and therefore only happens during tracing.
In [ ]:
external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)
iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Some iteration constructs are supported through AutoGraph. See the section on AutoGraph Transformations for an overview.

If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawback of `tf.py_function` is that it's not portable or particularly performant, nor does it work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors.

APIs like `tf.gather`, `tf.stack`, and `tf.TensorArray` can help you implement common looping patterns in native TensorFlow; see the sketch after the next cell.
In [ ]:
external_list = []
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])
f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1
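As mentioned above, you can often avoid `tf.py_function` by accumulating loop results with native TensorFlow ops instead of Python list appends. A minimal sketch (the function `cumulative_squares` is hypothetical) using `tf.TensorArray`:

In [ ]:
@tf.function
def cumulative_squares(n):
  # Accumulate one result per iteration, then stack into a single tensor.
  results = tf.TensorArray(tf.int32, size=n)
  for i in tf.range(n):
    results = results.write(i, i * i)
  return results.stack()

print(cumulative_squares(tf.constant(5)))  # [ 0  1  4  9 16]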
In [ ]:
@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
You can create variables inside a `Function` as long as those variables are only created the first time the function is executed.
In [ ]:
class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)
c = Count()
print(c())
print(c())
Another error you may encounter is a garbage-collected variable. Unlike normal Python functions, concrete functions only retain WeakRefs to the variables they close over, so you must retain a reference to any variables.
In [ ]:
external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var
traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))
del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, and `while`.

TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python.
In [ ]:
# Simple loop
@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
If you're curious, you can inspect the code AutoGraph generates.
In [ ]:
print(tf.autograph.to_code(f.python_function))
AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see AutoGraph tracing effects for more information.
In [ ]:
@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
See the reference documentation for additional restrictions on AutoGraph-converted `if` statements.
AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow looping ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.

This substitution is made in the following situations:

- `for x in y`: if `y` is a Tensor, convert to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops are generated.
- `while <condition>`: if `<condition>` is a Tensor, convert to `tf.while_loop`.

A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.
A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated `tf.Graph`.
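To make the difference concrete, here is a minimal sketch (both functions are hypothetical) contrasting a Python loop, which is unrolled into the graph during tracing, with a TensorFlow loop, whose body is traced exactly once:

In [ ]:
@tf.function
def unrolled(x):
  for i in range(3):       # Python loop: traced once per iteration
    print("Tracing body")  # prints three times during tracing
    x += i
  return x

@tf.function
def graph_loop(x):
  for i in tf.range(3):    # TensorFlow loop: converted to tf.while_loop
    print("Tracing body")  # prints once during tracing
    x += i
  return x

unrolled(tf.constant(0))
graph_loop(tf.constant(0))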
See the reference documentation for additional restrictions on AutoGraph-converted `for` and `while` statements.
A common pitfall is to loop over Python/NumPy data within a `tf.function`. This loop will execute during the tracing process, adding a copy of your model to the `tf.Graph` for each iteration of the loop.

If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop.
In [ ]:
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x)  # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensors`. The former will keep the data in Python and fetch it via `tf.py_function`, which can have performance implications, whereas the latter will bundle a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.

Reading data from files via `TFRecordDataset`/`CsvDataset`/etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the tf.data guide.
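As a minimal sketch of such a file-based pipeline (the filename `data-0.tfrecord` is hypothetical):

In [ ]:
# Stream records from disk; TensorFlow manages the I/O asynchronously.
dataset = tf.data.TFRecordDataset(["data-0.tfrecord"])  # hypothetical file
dataset = dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE)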
A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop.
In [ ]:
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]
  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
To learn about how to export and load a `Function`, see the SavedModel guide. To learn more about graph optimizations that are performed after tracing, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.