In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Migrate your TensorFlow 1 code to TensorFlow 2


This doc is for users of low-level TensorFlow APIs. If you are using the high-level APIs (tf.keras), there may be little or no action you need to take to make your code fully TensorFlow 2.0 compatible:

It is still possible to run 1.X code, unmodified (except for contrib), in TensorFlow 2.0:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

However, this does not let you take advantage of many of the improvements made in TensorFlow 2.0. This guide will help you upgrade your code, making it simpler, more performant, and easier to maintain.

Automatic conversion script

The first step, before attempting to implement the changes described in this doc, is to try running the upgrade script.

This will do an initial pass at upgrading your code to TensorFlow 2.0. But it can't make your code idiomatic to 2.0. Your code may still make use of tf.compat.v1 endpoints to access placeholders, sessions, collections, and other 1.x-style functionality.
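
For example, assuming a single script named foo.py (a hypothetical filename), you can invoke the upgrade script from the command line:

tf_upgrade_v2 \
  --infile foo.py \
  --outfile foo-upgraded.py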

Top-level behavioral changes

If your code works in TensorFlow 2.0 using tf.compat.v1.disable_v2_behavior(), there are still global behavioral changes you may need to address. The major changes are:

  • Eager execution, v1.enable_eager_execution(): Any code that implicitly uses a tf.Graph will fail. Be sure to wrap this code in a with tf.Graph().as_default() context.

  • Resource variables, v1.enable_resource_variables(): Some code may depend on non-deterministic behaviors enabled by TF reference variables. Resource variables are locked while being written to, and so provide more intuitive consistency guarantees.

    • This may change behavior in edge cases.
    • This may create extra copies and can have higher memory usage.
    • This can be disabled by passing use_resource=False to the tf.Variable constructor.
  • Tensor shapes, v1.enable_v2_tensorshape(): TF 2.0 simplifies the behavior of tensor shapes. Instead of t.shape[0].value you can say t.shape[0]. These changes should be small, and it makes sense to fix them right away. See TensorShape for examples.

  • Control flow, v1.enable_control_flow_v2(): The TF 2.0 control flow implementation has been simplified, and so produces different graph representations. Please file bugs for any issues.

Make the code 2.0-native

This guide will walk through several examples of converting TensorFlow 1.x code to TensorFlow 2.0. These changes will let your code take advantage of performance optimizations and simplified API calls.

In each case, the pattern is:

1. Replace v1.Session.run calls

Every v1.Session.run call should be replaced by a Python function.

  • The feed_dict and v1.placeholders become function arguments.
  • The fetches become the function's return value.
  • During conversion, eager execution allows easy debugging with standard Python tools like pdb.

After that, add a tf.function decorator to make it run efficiently in graph mode. See the Autograph Guide for more on how this works.
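
Here is a minimal sketch of this conversion (the computation itself is hypothetical):

# TF 1.x:
#   x = tf.placeholder(tf.float32, shape=(2,))
#   y = x * 2 + 1
#   result = sess.run(y, feed_dict={x: [1.0, 2.0]})

# TF 2.0: the placeholder becomes a function argument,
# and the fetch becomes the return value.
@tf.function
def times_two_plus_one(x):
  return x * 2 + 1

result = times_two_plus_one(tf.constant([1.0, 2.0]))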

Note that:

  • Unlike v1.Session.run a tf.function has a fixed return signature, and always returns all outputs. If this causes performance problems, create two separate functions.

  • There is no need for tf.control_dependencies or similar operations: a tf.function behaves as if it were run in the order written. tf.Variable assignments and tf.asserts, for example, are executed automatically.

2. Use Python objects to track variables and losses

All name-based variable tracking is strongly discouraged in TF 2.0. Use Python objects to track variables.

Use tf.Variable instead of v1.get_variable.

Every v1.variable_scope should be converted to a Python object. Typically this will be one of:

  • tf.keras.layers.Layer
  • tf.keras.Model
  • tf.Module

If you need to aggregate lists of variables (like tf.Graph.get_collection(tf.GraphKeys.VARIABLES)), use the .variables and .trainable_variables attributes of the Layer and Model objects.
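
As a brief sketch (the two-variable module below is hypothetical), any tf.Variable assigned as an attribute of a tf.Module, Layer, or Model is tracked automatically:

class Linear(tf.Module):
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    # Variables assigned to attributes are collected by the module.
    self.w = tf.Variable(tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')

  def __call__(self, x):
    return tf.matmul(x, self.w) + self.b

layer = Linear(3, 2)
print(len(layer.trainable_variables))  # 2, with no collections involved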

These Layer and Model classes implement several other properties that remove the need for global collections. Their .losses property can be a replacement for using the tf.GraphKeys.LOSSES collection.

See the keras guides for details.

Warning: Many tf.compat.v1 symbols use the global collections implicitly.

3. Upgrade your training loops

Use the highest level API that works for your use case. Prefer tf.keras.Model.fit over building your own training loops.

These high level functions manage a lot of the low-level details that might be easy to miss if you write your own training loop. For example, they automatically collect the regularization losses, and set the training=True argument when calling the model.

4. Upgrade your data input pipelines

Use tf.data datasets for data input. These objects are efficient, expressive, and integrate well with TensorFlow.

They can be passed directly to the tf.keras.Model.fit method.

model.fit(dataset, epochs=5)

They can also be iterated over directly in standard Python:

for example_batch, label_batch in dataset:
    break

5. Migrate off compat.v1 symbols

The tf.compat.v1 module contains the complete TensorFlow 1.x API, with its original semantics.

The TF2 upgrade script will convert symbols to their 2.0 equivalents if such a conversion is safe, i.e., if it can determine that the behavior of the 2.0 version is exactly equivalent (for instance, it will rename v1.arg_max to tf.argmax, since those are the same function).

After the upgrade script is done with a piece of code, it is likely that there are still many mentions of compat.v1. It is worth going through the code and converting these manually to the 2.0 equivalent (the log should mention the equivalent if there is one).
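
For example, a sketch of one such manual conversion (v1.to_float is just one of many compat.v1 symbols you might encounter):

x = tf.constant([1, 2, 3])
# Before (1.x semantics, via compat.v1): y = tf.compat.v1.to_float(x)
# After (2.0-native):
y = tf.cast(x, tf.float32)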

Converting models

Setup


In [0]:
import tensorflow as tf


import tensorflow_datasets as tfds

Low-level variables & operator execution

Examples of low-level API use include:

  • using variable scopes to control reuse
  • creating variables with v1.get_variable
  • accessing collections explicitly
  • accessing collections implicitly with methods like:

    • v1.global_variables
    • v1.losses.get_regularization_loss
  • using v1.placeholder to set up graph inputs

  • executing graphs with Session.run
  • initializing variables manually

Before converting

Here is what these patterns may look like in code using TensorFlow 1.x.

in_a = tf.placeholder(dtype=tf.float32, shape=(2))
in_b = tf.placeholder(dtype=tf.float32, shape=(2))

def forward(x):
  with tf.variable_scope("matmul", reuse=tf.AUTO_REUSE):
    W = tf.get_variable("W", initializer=tf.ones(shape=(2,2)),
                        regularizer=tf.contrib.layers.l2_regularizer(0.04))
    b = tf.get_variable("b", initializer=tf.zeros(shape=(2)))
    return W * x + b

out_a = forward(in_a)
out_b = forward(in_b)

reg_loss=tf.losses.get_regularization_loss(scope="matmul")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outs = sess.run([out_a, out_b, reg_loss],
                feed_dict={in_a: [1, 0], in_b: [0, 1]})

After converting

In the converted code:

  • The variables are local Python objects.
  • The forward function still defines the calculation.
  • The Session.run call is replaced with a call to forward.
  • The optional tf.function decorator can be added for performance.
  • The regularizations are calculated manually, without referring to any global collection.
  • No sessions or placeholders.

In [0]:
W = tf.Variable(tf.ones(shape=(2,2)), name="W")
b = tf.Variable(tf.zeros(shape=(2)), name="b")

@tf.function
def forward(x):
  return W * x + b

out_a = forward([1,0])
print(out_a)

In [0]:
out_b = forward([0,1])

regularizer = tf.keras.regularizers.l2(0.04)
reg_loss = regularizer(W)

Models based on tf.layers

The v1.layers module contains layer functions that rely on v1.variable_scope to define and reuse variables.

Before converting

def model(x, training, scope='model'):
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    x = tf.layers.conv2d(x, 32, 3, activation=tf.nn.relu,
          kernel_regularizer=tf.contrib.layers.l2_regularizer(0.04))
    x = tf.layers.max_pooling2d(x, (2, 2), 1)
    x = tf.layers.flatten(x)
    x = tf.layers.dropout(x, 0.1, training=training)
    x = tf.layers.dense(x, 64, activation=tf.nn.relu)
    x = tf.layers.batch_normalization(x, training=training)
    x = tf.layers.dense(x, 10)
    return x

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

After converting

  • The simple stack of layers fits neatly into tf.keras.Sequential. (For more complex models see custom layers and models, and the functional API.)
  • The model tracks the variables, and regularization losses.
  • The conversion was one-to-one because there is a direct mapping from v1.layers to tf.keras.layers.

Most arguments stayed the same. But notice the differences:

  • The training argument is passed to each layer by the model when it runs.
  • The first argument to the original model function (the input x) is gone. This is because object layers separate building the model from calling the model.

Also note that:

  • If you were using regularizers or initializers from tf.contrib, these have more argument changes than others.
  • The code no longer writes to collections, so functions like v1.losses.get_regularization_loss will no longer return these values, potentially breaking your training loops.

In [0]:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.04),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

In [0]:
train_out = model(train_data, training=True)
print(train_out)

In [0]:
test_out = model(test_data, training=False)
print(test_out)

In [0]:
# Here are all the trainable variables.
len(model.trainable_variables)

In [0]:
# Here is the regularization loss.
model.losses

Mixed variables & v1.layers

Existing code often mixes lower-level TF 1.x variables and operations with higher-level v1.layers.

Before converting

def model(x, training, scope='model'):
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    W = tf.get_variable(
      "W", dtype=tf.float32,
      initializer=tf.ones(shape=x.shape),
      regularizer=tf.contrib.layers.l2_regularizer(0.04),
      trainable=True)
    if training:
      x = x + W
    else:
      x = x + W * 0.5
    x = tf.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
    x = tf.layers.max_pooling2d(x, (2, 2), 1)
    x = tf.layers.flatten(x)
    return x

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

After converting

To convert this code, follow the pattern of mapping layers to layers as in the previous example.

The general pattern is:

  • Collect layer parameters in __init__.
  • Build the variables in build.
  • Execute the calculations in call, and return the result.

The v1.variable_scope is essentially a layer of its own. So rewrite it as a tf.keras.layers.Layer. See the guide for details.


In [0]:
# Create a custom layer for part of the model
class CustomLayer(tf.keras.layers.Layer):
  def __init__(self, *args, **kwargs):
    super(CustomLayer, self).__init__(*args, **kwargs)

  def build(self, input_shape):
    self.w = self.add_weight(
        shape=input_shape[1:],
        dtype=tf.float32,
        initializer=tf.keras.initializers.ones(),
        regularizer=tf.keras.regularizers.l2(0.02),
        trainable=True)

  # The call method will sometimes get used in graph mode,
  # where `training` is turned into a tensor rather than a Python boolean.
  @tf.function
  def call(self, inputs, training=None):
    if training:
      return inputs + self.w
    else:
      return inputs + self.w * 0.5

In [0]:
custom_layer = CustomLayer()
print(custom_layer([1]).numpy())
print(custom_layer([1], training=True).numpy())

In [0]:
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

# Build the model including the custom layer
model = tf.keras.Sequential([
    CustomLayer(input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
])

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

Some things to note:

  • Subclassed Keras models & layers need to run in both v1 graphs (no automatic control dependencies) and in eager mode

    • Wrap the call() in a tf.function() to get autograph and automatic control dependencies
  • Don't forget to accept a training argument to call.

    • Sometimes it is a tf.Tensor
    • Sometimes it is a Python boolean.
  • Create model variables in constructor or Model.build using self.add_weight().

    • In Model.build you have access to the input shape, so can create weights with matching shape.
    • Using tf.keras.layers.Layer.add_weight allows Keras to track variables and regularization losses.
  • Don't keep tf.Tensors in your objects.

    • They might get created either in a tf.function or in the eager context, and these tensors behave differently.
    • Use tf.Variable objects for state; they are always usable from both contexts.
    • tf.Tensors are only for intermediate values.

A note on Slim & contrib.layers

A large amount of older TensorFlow 1.x code uses the Slim library, which was packaged with TensorFlow 1.x as tf.contrib.layers. As a contrib module, this is no longer available in TensorFlow 2.0, even in tf.compat.v1. Converting code using Slim to TF 2.0 is more involved than converting repositories that use v1.layers. In fact, it may make sense to convert your Slim code to v1.layers first, then convert to Keras.

  • Remove arg_scopes; all args need to be explicit.
  • If you use them, split normalizer_fn and activation_fn into their own layers.
  • Separable conv layers map to one or more different Keras layers (depthwise, pointwise, and separable Keras layers).
  • Slim and v1.layers have different arg names & default values.
  • Some args have different scales.
  • If you use Slim pre-trained models, try out Keras's pre-trained models from tf.keras.applications or TF Hub's TF2 SavedModels exported from the original Slim code.

Some tf.contrib layers might not have been moved to core TensorFlow but have instead been moved to the TensorFlow Addons package.

Training

There are many ways to feed data to a tf.keras model. They will accept Python generators and Numpy arrays as input.

The recommended way to feed data to a model is to use the tf.data package, which contains a collection of high performance classes for manipulating data.

If you are still using tf.queue, these are now only supported as data-structures, not as input pipelines.

Using Datasets

The TensorFlow Datasets package (tfds) contains utilities for loading predefined datasets as tf.data.Dataset objects.

For this example, load the MNIST dataset using tfds:


In [0]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Then prepare the data for training:

  • Re-scale each image.
  • Shuffle the order of the examples.
  • Collect batches of images and labels.

In [0]:
BUFFER_SIZE = 10 # Use a much larger value for real code.
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

To keep the example short, trim the dataset to only return 5 batches:


In [0]:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)

In [0]:
image_batch, label_batch = next(iter(train_data))

Use Keras training loops

If you don't need low level control of your training process, using Keras's built-in fit, evaluate, and predict methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or sub-classed).

The advantages of these methods include:

  • They accept Numpy arrays, Python generators, and tf.data.Datasets.
  • They apply regularization and activation losses automatically.
  • They support tf.distribute for multi-device training.
  • They support arbitrary callables as losses and metrics.
  • They support callbacks like tf.keras.callbacks.TensorBoard, and custom callbacks.
  • They are performant, automatically using TensorFlow graphs.

Here is an example of training a model using a Dataset. (For details on how this works see tutorials.)


In [0]:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))

Write your own loop

If the Keras model's training step works for you, but you need more control outside that step, consider using the tf.keras.Model.train_on_batch method in your own data-iteration loop.

Remember: Many things can be implemented as a tf.keras.callbacks.Callback.

This method has many of the advantages of the methods mentioned in the previous section, but gives the user control of the outer loop.

You can also use tf.keras.Model.test_on_batch or tf.keras.Model.evaluate to check performance during training.

Note: train_on_batch and test_on_batch, by default, return the loss and metrics for the single batch. If you pass reset_metrics=False, they return accumulated metrics, and you must remember to appropriately reset the metric accumulators. Also remember that some metrics like AUC require reset_metrics=False to be calculated correctly.

To continue training the above model:


In [0]:
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

for epoch in range(NUM_EPOCHS):
  # Reset the metric accumulators
  model.reset_metrics()

  for image_batch, label_batch in train_data:
    result = model.train_on_batch(image_batch, label_batch)
    metrics_names = model.metrics_names
    print("train: ",
          "{}: {:.3f}".format(metrics_names[0], result[0]),
          "{}: {:.3f}".format(metrics_names[1], result[1]))
  for image_batch, label_batch in test_data:
    result = model.test_on_batch(image_batch, label_batch,
                                 # return accumulated metrics
                                 reset_metrics=False)
  metrics_names = model.metrics_names
  print("\neval: ",
        "{}: {:.3f}".format(metrics_names[0], result[0]),
        "{}: {:.3f}".format(metrics_names[1], result[1]))

Customize the training step

If you need more flexibility and control, you can have it by implementing your own training loop. There are three steps:

  1. Iterate over a Python generator or tf.data.Dataset to get batches of examples.
  2. Use tf.GradientTape to collect gradients.
  3. Use one of the tf.keras.optimizers to apply weight updates to the model's variables.

Remember:

  • Always include a training argument on the call method of subclassed layers and models.
  • Make sure to call the model with the training argument set correctly.
  • Depending on usage, model variables may not exist until the model is run on a batch of data.
  • You need to manually handle things like regularization losses for the model.

Note the simplifications relative to v1:

  • There is no need to run variable initializers. Variables are initialized on creation.
  • There is no need to add manual control dependencies. Even in a tf.function, operations act as they do in eager mode.

In [0]:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)

New-style metrics and losses

In TensorFlow 2.0, metrics and losses are objects. These work both eagerly and in tf.functions.

A loss object is callable, and expects (y_true, y_pred) as arguments:


In [0]:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0, 3.0]]).numpy()

A metric object has the following methods:

  • Metric.update_state() — add new observations.
  • Metric.result() — get the current result of the metric, given the observed values.
  • Metric.reset_states() — clear all observations.

The object itself is callable. Calling it updates the state with new observations, as with update_state, and returns the new result of the metric.
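
For example, a quick sketch using tf.keras.metrics.Mean:

m = tf.keras.metrics.Mean()
m.update_state([1, 2, 3])
print(m.result().numpy())  # 2.0
m([4.0])                   # calling the metric also updates its state
print(m.result().numpy())  # 2.5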

You don't have to manually initialize a metric's variables, and because TensorFlow 2.0 has automatic control dependencies, you don't need to worry about those either.

The code below uses a metric to keep track of the mean loss observed within a custom training loop.


In [0]:
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))

Keras metric names

In TensorFlow 2.0, Keras models are more consistent about handling metric names.

Now when you pass a string in the list of metrics, that exact string is used as the metric's name. These names are visible in the history object returned by model.fit, and in the logs passed to keras.callbacks.


In [0]:
model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)

In [0]:
history.history.keys()

This differs from previous versions, where passing metrics=["accuracy"] would result in dict_keys(['loss', 'acc']).

Keras optimizers

The optimizers in v1.train, like v1.train.AdamOptimizer and v1.train.GradientDescentOptimizer, have equivalents in tf.keras.optimizers.

Convert v1.train to keras.optimizers

Here are things to keep in mind when converting your optimizers:

  • Upgrading your optimizers may make old checkpoints incompatible.
  • All epsilons now default to 1e-7 instead of 1e-8 (which is negligible in most use cases).
  • v1.train.GradientDescentOptimizer can be directly replaced by tf.keras.optimizers.SGD.
  • v1.train.MomentumOptimizer can be directly replaced by the SGD optimizer using the momentum argument: tf.keras.optimizers.SGD(..., momentum=...) (see the sketch after this list).
  • v1.train.AdamOptimizer can be converted to use tf.keras.optimizers.Adam. The beta1 and beta2 arguments have been renamed to beta_1 and beta_2.
  • v1.train.RMSPropOptimizer can be converted to tf.keras.optimizers.RMSprop. The decay argument has been renamed to rho.
  • v1.train.AdadeltaOptimizer can be converted directly to tf.keras.optimizers.Adadelta.
  • tf.train.AdagradOptimizer can be converted directly to tf.keras.optimizers.Adagrad.
  • tf.train.FtrlOptimizer can be converted directly to tf.keras.optimizers.Ftrl. The accum_name and linear_name arguments have been removed.
  • The tf.contrib.AdamaxOptimizer and tf.contrib.NadamOptimizer can be converted directly to tf.keras.optimizers.Adamax and tf.keras.optimizers.Nadam. The beta1 and beta2 arguments have been renamed to beta_1 and beta_2.
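
For example, here is a hypothetical v1 momentum optimizer and its 2.0 replacement:

# TF 1.x:
#   optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
# TF 2.0:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)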

New defaults for some tf.keras.optimizers

Warning: If you see a change in convergence behavior for your models, check the default learning rates.

There are no changes for optimizers.SGD, optimizers.Adam, or optimizers.RMSprop.

The following default learning rates have changed:

  • optimizers.Adagrad from 0.01 to 0.001
  • optimizers.Adadelta from 1.0 to 0.001
  • optimizers.Adamax from 0.002 to 0.001
  • optimizers.Nadam from 0.002 to 0.001

TensorBoard

TensorFlow 2 includes significant changes to the tf.summary API used to write summary data for visualization in TensorBoard. For a general introduction to the new tf.summary, there are several tutorials available that use the TF 2 API. This includes a TensorBoard TF 2 Migration Guide.

Saving & Loading

Checkpoint compatibility

TensorFlow 2.0 uses object-based checkpoints.

Old-style name-based checkpoints can still be loaded, if you're careful. The code conversion process may result in variable name changes, but there are workarounds.

The simplest approach is to line up the names of the new model with the names in the checkpoint:

  • Variables still all have a name argument you can set.
  • Keras models also take a name argument, which they set as the prefix for their variables.
  • The v1.name_scope function can be used to set variable name prefixes. This is very different from tf.variable_scope. It only affects names, and doesn't track variables & reuse.

If that does not work for your use-case, try the v1.train.init_from_checkpoint function. It takes an assignment_map argument, which specifies the mapping from old names to new names.
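
A minimal sketch of such a call (the checkpoint path and variable names are hypothetical):

tf.compat.v1.train.init_from_checkpoint(
    '/path/to/old_checkpoint',
    assignment_map={'old_scope/W': 'new_scope/W'})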

Note: Unlike object based checkpoints, which can defer loading, name-based checkpoints require that all variables be built when the function is called. Some models defer building variables until you call build or run the model on a batch of data.

The TensorFlow Estimator repository includes a conversion tool to upgrade the checkpoints for premade estimators from TensorFlow 1.X to 2.0. It may serve as an example of how to build a tool for a similar use-case.

Saved models compatibility

There are no significant compatibility concerns for saved models.

  • TensorFlow 1.x saved_models work in TensorFlow 2.x.
  • TensorFlow 2.x saved_models work in TensorFlow 1.x—if all the ops are supported.
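
For instance, loading a 1.x SavedModel in TensorFlow 2.x is a single call (the path below is hypothetical):

loaded = tf.saved_model.load('/path/to/tf1_saved_model')
print(loaded.signatures.keys())  # inspect the available signatures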

A Graph.pb or Graph.pbtxt

There is no straightforward way to upgrade a raw Graph.pb file to TensorFlow 2.0. Your best bet is to upgrade the code that generated the file.

But, if you have a "Frozen graph" (a tf.Graph where the variables have been turned into constants), then it is possible to convert this to a concrete_function using v1.wrap_function:


In [0]:
def wrap_frozen_graph(graph_def, inputs, outputs):
  def _imports_graph_def():
    tf.compat.v1.import_graph_def(graph_def, name="")
  wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      tf.nest.map_structure(import_graph.as_graph_element, inputs),
      tf.nest.map_structure(import_graph.as_graph_element, outputs))

For example, here is a frozen graph for Inception v1, from 2016:


In [0]:
path = tf.keras.utils.get_file(
    'inception_v1_2016_08_28_frozen.pb',
    'http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz',
    untar=True)

Load the tf.GraphDef:


In [0]:
graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(open(path, 'rb').read())

Wrap it into a concrete_function:


In [0]:
inception_func = wrap_frozen_graph(
    graph_def, inputs='input:0',
    outputs='InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu:0')

Pass it a tensor as input:


In [0]:
input_img = tf.ones([1,224,224,3], dtype=tf.float32)
inception_func(input_img).shape

Estimators

Training with Estimators

Estimators are supported in TensorFlow 2.0.

When you use estimators, you can use input_fn(), tf.estimator.TrainSpec, and tf.estimator.EvalSpec from TensorFlow 1.x.

Here is an example using input_fn with train and evaluate specs.

Creating the input_fn and train/eval specs


In [0]:
# Define the estimator's input_fn
def input_fn():
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000
  BATCH_SIZE = 64

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label[..., tf.newaxis]

  train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  return train_data.repeat()

# Define train & eval specs
train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
                                    max_steps=STEPS_PER_EPOCH * NUM_EPOCHS)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
                                  steps=STEPS_PER_EPOCH)

Using a Keras model definition

There are some differences in how to construct your estimators in TensorFlow 2.0.

We recommend that you define your model using Keras, then use the tf.keras.estimator.model_to_estimator utility to turn your model into an estimator. The code below shows how to use this utility when creating and training an estimator.


In [0]:
def make_model():
  return tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
  ])

In [0]:
model = make_model()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

estimator = tf.keras.estimator.model_to_estimator(
  keras_model = model
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

Note: We do not support creating weighted metrics in Keras and converting them to weighted metrics in the Estimator API using model_to_estimator. You will have to create these metrics directly on the estimator spec using the add_metrics function.

Using a custom model_fn

If you have an existing custom estimator model_fn that you need to maintain, you can convert your model_fn to use a Keras model.

However, for compatibility reasons, a custom model_fn will still run in 1.x-style graph mode. This means there is no eager execution and no automatic control dependencies.

Custom model_fn with minimal changes

If you prefer minimal changes to your existing code, you can make your custom model_fn work in TF 2.0 by using tf.compat.v1 symbols such as optimizers and metrics.

Using a Keras model in a custom model_fn is similar to using it in a custom training loop:

  • Set the training phase appropriately, based on the mode argument.
  • Explicitly pass the model's trainable_variables to the optimizer.

But there are important differences, relative to a custom loop:

  • Instead of using Model.losses, extract the losses using Model.get_losses_for.
  • Extract the model's updates using Model.get_updates_for.

Note: "Updates" are changes that need to be applied to a model after each batch. For example, the moving averages of the mean and variance in a layers.BatchNormalization layer.

The following code creates an estimator from a custom model_fn, illustrating all of these concerns.


In [0]:
def my_model_fn(features, labels, mode):
  model = make_model()

  optimizer = tf.compat.v1.train.AdamOptimizer()
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  predictions = model(features, training=training)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss = loss_fn(labels, predictions) + tf.math.add_n(reg_losses)

  accuracy = tf.compat.v1.metrics.accuracy(labels=labels,
                                           predictions=tf.math.argmax(predictions, axis=1),
                                           name='acc_op')

  update_ops = model.get_updates_for(None) + model.get_updates_for(features)
  minimize_op = optimizer.minimize(
      total_loss,
      var_list=model.trainable_variables,
      global_step=tf.compat.v1.train.get_or_create_global_step())
  train_op = tf.group(minimize_op, update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op, eval_metric_ops={'accuracy': accuracy})

# Create the Estimator & Train
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

Custom model_fn with TF 2.0 symbols

If you want to get rid of all TF 1.x symbols and upgrade your custom model_fn to native TF 2.0, you need to update the optimizer and metrics to tf.keras.optimizers and tf.keras.metrics.

In the custom model_fn, besides the above changes, more upgrades need to be made:

  • Use tf.keras.optimizers instead of v1.train.Optimizer.
  • Explicitly pass the model's trainable_variables to the tf.keras.optimizers.
  • To compute the train_op/minimize_op,
    • Use Optimizer.get_updates() if the loss is a scalar loss Tensor (not a callable). The first element in the returned list is the desired train_op/minimize_op.
    • If the loss is a callable (such as a function), use Optimizer.minimize() to get the train_op/minimize_op.
  • Use tf.keras.metrics instead of tf.compat.v1.metrics for evaluation.

For the above example of my_model_fn, the migrated code with 2.0 symbols is shown below:


In [0]:
def my_model_fn(features, labels, mode):
  model = make_model()

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  predictions = model(features, training=training)

  # Get both the unconditional losses (the None part)
  # and the input-conditional losses (the features part).
  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss = loss_obj(labels, predictions) + tf.math.add_n(reg_losses)

  # Upgrade to tf.keras.metrics.
  accuracy_obj = tf.keras.metrics.Accuracy(name='acc_obj')
  accuracy = accuracy_obj.update_state(
      y_true=labels, y_pred=tf.math.argmax(predictions, axis=1))

  train_op = None
  if training:
    # Upgrade to tf.keras.optimizers.
    optimizer = tf.keras.optimizers.Adam()
    # Manually assign tf.compat.v1.global_step variable to optimizer.iterations
    # to make tf.compat.v1.train.global_step increased correctly.
    # This assignment is a must for any `tf.train.SessionRunHook` specified in
    # estimator, as SessionRunHooks rely on global step.
    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
    # Get both the unconditional updates (the None part)
    # and the input-conditional updates (the features part).
    update_ops = model.get_updates_for(None) + model.get_updates_for(features)
    # Compute the minimize_op.
    minimize_op = optimizer.get_updates(
        total_loss,
        model.trainable_variables)[0]
    train_op = tf.group(minimize_op, *update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op,
    eval_metric_ops={'Accuracy': accuracy_obj})

# Create the Estimator & Train.
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

Premade Estimators

Premade Estimators in the family of tf.estimator.DNN*, tf.estimator.Linear*, and tf.estimator.DNNLinearCombined* are still supported in the TensorFlow 2.0 API; however, some arguments have changed:

  1. input_layer_partitioner: Removed in 2.0.
  2. loss_reduction: Updated to tf.keras.losses.Reduction instead of tf.compat.v1.losses.Reduction. Its default value is also changed to tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE from tf.compat.v1.losses.Reduction.SUM.
  3. optimizer, dnn_optimizer, and linear_optimizer: these args have been updated to tf.keras.optimizers instead of tf.compat.v1.train.Optimizer.

To migrate the above changes:

  1. No migration is needed for input_layer_partitioner since Distribution Strategy will handle it automatically in TF 2.0.
  2. For loss_reduction, check tf.keras.losses.Reduction for the supported options.
  3. For optimizer args, if you do not pass in an optimizer, dnn_optimizer, or linear_optimizer arg, or if you specify the optimizer arg as a string in your code, you don't need to change anything: tf.keras.optimizers is used by default. Otherwise, you need to update it from tf.compat.v1.train.Optimizer to its corresponding tf.keras.optimizers equivalent.

Checkpoint Converter

The migration to keras.optimizers will break checkpoints saved using TF 1.x, as tf.keras.optimizers generates a different set of variables to be saved in checkpoints. To make old checkpoints reusable after your migration to TF 2.0, try the checkpoint converter tool.


In [0]:
! curl -O https://raw.githubusercontent.com/tensorflow/estimator/master/tensorflow_estimator/python/estimator/tools/checkpoint_converter.py

The tool has built-in help:


In [0]:
! python checkpoint_converter.py -h

TensorShape

This class was simplified to hold ints, instead of tf.compat.v1.Dimension objects. So there is no need to access .value to get an int.

Individual tf.compat.v1.Dimension objects are still accessible from tf.TensorShape.dims.

The following demonstrates the differences between TensorFlow 1.x and TensorFlow 2.0.


In [0]:
# Create a shape and choose an index
i = 0
shape = tf.TensorShape([16, None, 256])
shape

If you had this in TF 1.x:

value = shape[i].value

Then do this in TF 2.0:


In [0]:
value = shape[i]
value

If you had this in TF 1.x:

for dim in shape:
    value = dim.value
    print(value)

Then do this in TF 2.0:


In [0]:
for value in shape:
  print(value)

If you had this in TF 1.x (or used any other dimension method):

dim = shape[i]
dim.assert_is_compatible_with(other_dim)

Then do this in TF 2.0:


In [0]:
other_dim = 16
Dimension = tf.compat.v1.Dimension

if shape.rank is None:
  dim = Dimension(None)
else:
  dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method

In [0]:
shape = tf.TensorShape(None)

if shape:
  dim = shape.dims[i]
  dim.is_compatible_with(other_dim) # or any other dimension method

The boolean value of a tf.TensorShape is True if the rank is known, False otherwise.


In [0]:
print(bool(tf.TensorShape([])))      # Scalar
print(bool(tf.TensorShape([0])))     # 0-length vector
print(bool(tf.TensorShape([1])))     # 1-length vector
print(bool(tf.TensorShape([None])))  # Unknown-length vector
print(bool(tf.TensorShape([1, 10, 100])))       # 3D tensor
print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions
print()
print(bool(tf.TensorShape(None)))  # A tensor with unknown rank.

Other Changes

  • Remove tf.colocate_with: TensorFlow's device placement algorithms have improved significantly. This should no longer be necessary. If removing it causes a performance degradation, please file a bug.

  • Replace v1.ConfigProto usage with the equivalent functions from tf.config, as in the sketch below.
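
For example, one such replacement (a sketch, assuming your 1.x code set allow_soft_placement):

# TF 1.x:
#   config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
#   sess = tf.compat.v1.Session(config=config)
# TF 2.0:
tf.config.set_soft_device_placement(True)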

Conclusions

The overall process is:

  1. Run the upgrade script.
  2. Remove contrib symbols.
  3. Switch your models to an object oriented style (Keras).
  4. Use tf.keras or tf.estimator training and evaluation loops where you can.
  5. Otherwise, use custom loops, but be sure to avoid sessions & collections.

It takes a little work to convert code to idiomatic TensorFlow 2.0, but every change results in:

  • Fewer lines of code.
  • Increased clarity and simplicity.
  • Easier debugging.