MNIST Model TensorFlow Training, IREE Execution


This notebook creates and trains a TensorFlow 2.0 model for recognizing handwritten digits using the MNIST dataset, then compiles and executes that trained model using IREE.

Running Locally

  • Refer to iree/docs/ for general information
  • Ensure that you have a recent version of TensorFlow 2.0 installed on your system
  • Enable IREE/TF integration by adding to your user.bazelrc: build --define=iree_tensorflow=true
  • Start colab by running python colab/ (see that file for additional instructions)
  • Note: you may need to restart your runtime in order to re-run certain cells. Some of the APIs are not yet stable enough for repeated invocations

Setup Steps

In [2]:
import os
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from import compiler as ireec
from pyiree import rt as ireert


SAVE_PATH = os.path.join(os.environ["HOME"], "saved_models")
os.makedirs(SAVE_PATH, exist_ok=True)

# Print version information for future notebook users to reference.
print("TensorFlow version: ", tf.__version__)
print("Numpy version: ", np.__version__)

TensorFlow version:  2.5.0-dev20200626
Numpy version:  1.18.4

In [3]:
#@title Notebook settings { run: "auto" }

#@markdown -----
#@markdown ### Configuration

backend_choice = "GPU (vulkan-spirv)" #@param [ "GPU (vulkan-spirv)", "CPU (VMLA)" ]

if backend_choice == "GPU (vulkan-spirv)":
  backend_name = "vulkan-spirv"
  driver_name = "vulkan"
  backend_name = "vmla"
  driver_name = "vmla"
tf.print("Using IREE compiler backend '%s' and runtime driver '%s'" % (backend_name, driver_name))

#@markdown -----
#@markdown ### Training Parameters

#@markdown <sup>Batch size used to subdivide the training and evaluation samples</sup>
batch_size = 200  #@param { type: "slider", min: 10, max: 400 }

#@markdown <sup>Epochs for training/eval. Higher values take longer to run but generally produce more accurate models</sup>
num_epochs = 5    #@param { type: "slider", min:  1, max:  20 }

#@markdown -----

Using IREE compiler backend 'vulkan-spirv' and runtime driver 'vulkan'

Create and Train MNIST Model in TensorFlow

The specific details of the training process here aren't critical to the model compilation and execution through IREE.

In [4]:
#@title Load MNIST dataset, setup training and evaluation

NUM_CLASSES = 10  # One per digit [0, 1, 2, ..., 9]

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
tf.print("Loaded MNIST dataset!")

x_train = x_train.reshape(x_train.shape[0], IMG_ROWS, IMG_COLS, 1)
x_test = x_test.reshape(x_test.shape[0], IMG_ROWS, IMG_COLS, 1)
input_shape = (IMG_ROWS, IMG_COLS, 1)

# Scale pixel values from [0, 255] integers to [0.0, 1.0] floats.
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

steps_per_epoch = int(x_train.shape[0] / batch_size)
steps_per_eval = int(x_test.shape[0] / batch_size)

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

# Construct batched datasets for training/evaluation.
train_dataset =, y_train))
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
test_dataset =, y_test))
test_dataset = test_dataset.batch(batch_size, drop_remainder=True)

# Create a distribution strategy for the dataset (single machine).
strategy = tf.distribute.experimental.CentralStorageStrategy()
train_dist_ds = strategy.experimental_distribute_dataset(train_dataset)
test_dist_ds = strategy.experimental_distribute_dataset(test_dataset)

tf.print("Configured data for training and evaluation!")
tf.print("  sample shape: %s" % str(x_train[0].shape))
tf.print("  training samples: %s" % x_train.shape[0])
tf.print("  test     samples: %s" % x_test.shape[0])
tf.print("  epochs: %s" % num_epochs)
tf.print("  steps/epoch: %s" % steps_per_epoch)
tf.print("  steps/eval : %s" % steps_per_eval)

tf.print("Sample image from the dataset:")
sample_image = x_test[SAMPLE_EXAMPLE_INDEX]
sample_image_batch = np.expand_dims(sample_image, axis=0)
sample_label = y_test[SAMPLE_EXAMPLE_INDEX]
plt.imshow(sample_image.reshape(IMG_ROWS, IMG_COLS))
tf.print("\nGround truth labels: %s" % str(sample_label))

Loaded MNIST dataset!
INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:localhost/replica:0/task:0/device:CPU:0'], variable_device = '/job:localhost/replica:0/task:0/device:CPU:0'
Configured data for training and evaluation!
  sample shape: (28, 28, 1)
  training samples: 60000
  test     samples: 10000
  epochs: 5
  steps/epoch: 300
  steps/eval : 50

Sample image from the dataset:
Ground truth labels: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]

In [5]:
#@title Define MNIST model architecture using tf.keras API

def simple_mnist_model(input_shape):
  """Creates a simple (multi-layer perceptron) MNIST model."""

  model = tf.keras.models.Sequential()
  # Flatten to a 1d array (e.g. 28x28 -> 784)
  # Fully-connected neural layer with 128 neurons, RELU activation
  model.add(tf.keras.layers.Dense(128, activation='relu'))
  # Fully-connected neural layer returning probability scores for each class
  model.add(tf.keras.layers.Dense(10, activation='softmax'))
  return model

In [6]:
#@title Train the Keras model

with strategy.scope():
  model = simple_mnist_model(input_shape)
  tf.print("Constructed Keras MNIST model, training...")

  optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)
  training_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
  training_accuracy = tf.keras.metrics.CategoricalAccuracy(
      "training_accuracy", dtype=tf.float32)
  test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32)
  test_accuracy = tf.keras.metrics.CategoricalAccuracy(
      "test_accuracy", dtype=tf.float32)

  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss = tf.keras.losses.categorical_crossentropy(labels, logits)
        loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      training_accuracy.update_state(labels, logits), args=(next(iterator),))

  def test_step(iterator):
    """Evaluation StepFn."""

    def step_fn(inputs):
      images, labels = inputs
      logits = model(images, training=False)
      loss = tf.keras.losses.categorical_crossentropy(labels, logits)
      loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync
      test_accuracy.update_state(labels, logits), args=(next(iterator),))

  for epoch in range(0, num_epochs):
    tf.print("Running epoch #%s" % (epoch + 1))

    train_iterator = iter(train_dist_ds)
    for step in range(steps_per_epoch):
    tf.print("  Training loss: %f, accuracy: %f" % (training_loss.result(), training_accuracy.result() * 100))

    test_iterator = iter(test_dist_ds)
    for step in range(steps_per_eval):
    tf.print("  Test loss    : %f, accuracy: %f" % (test_loss.result(), test_accuracy.result() * 100))

  tf.print("Completed training!")

  # Run a single prediction on the trained model
  tf_prediction = model(sample_image_batch, training=False)
  tf.print("Sample prediction:")
  tf.print(tf_prediction[0] * 100.0, summarize=100)

Constructed Keras MNIST model, training...
Running epoch #1
  Training loss: 0.732439, accuracy: 81.403336
  Test loss    : 0.390855, accuracy: 89.490005
Running epoch #2
  Training loss: 0.365308, accuracy: 89.811668
  Test loss    : 0.315630, accuracy: 91.119995
Running epoch #3
  Training loss: 0.312111, accuracy: 91.129997
  Test loss    : 0.281829, accuracy: 92.040001
Running epoch #4
  Training loss: 0.281028, accuracy: 92.038330
  Test loss    : 0.258432, accuracy: 92.629997
Running epoch #5
  Training loss: 0.257909, accuracy: 92.753334
  Test loss    : 0.240058, accuracy: 93.229996
Completed training!

Sample prediction:
[0.243134052 0.00337268948 95.5214081 0.925373673 2.25061958e-05 0.992091119 2.20864391 3.87712953e-06 0.105901182 4.44369543e-05]

In [7]:
#@title Export the trained model as a SavedModel, with IREE-compatible settings

# Since the model was written in sequential style, explicitly wrap in a module.
saved_model_dir = "/tmp/"
inference_module = tf.Module()
inference_module.model = model
# Hack: Convert to static shape. Won't be necessary once dynamic shapes are in.
dynamic_input_shape = list(model.inputs[0].shape)
dynamic_input_shape[0] = 1  # Make fixed (batch=1)
# Produce a concrete function.
inference_module.predict = tf.function(
        tf.TensorSpec(dynamic_input_shape, model.inputs[0].dtype)])(
            lambda x:, training=False))
save_options = tf.saved_model.SaveOptions(save_debug_info=True)
tf.print("Exporting SavedModel to %s" % saved_model_dir), saved_model_dir, options=save_options)

Exporting SavedModel to /tmp/
WARNING:tensorflow:From c:\users\scott\scoop\apps\python\current\lib\site-packages\tensorflow\python\training\tracking\ Model.state_updates (from is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From c:\users\scott\scoop\apps\python\current\lib\site-packages\tensorflow\python\training\tracking\ Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/\assets

Compile and Execute MNIST Model using IREE

In [8]:
#@title Load the SavedModel into IREE's compiler as MLIR mhlo

compiler_module = ireec.tf_load_saved_model(
    saved_model_dir, exported_names=["predict"])
tf.print("Imported MLIR:\n", compiler_module.to_asm(large_element_limit=100))

# Write to a file for use outside of this notebook.
mnist_mlir_path = os.path.join(SAVE_PATH, "mnist.mlir")
with open(mnist_mlir_path, "wt") as output_file:
print("Wrote MLIR to path '%s'" % mnist_mlir_path)

Imported MLIR:

module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 443 : i32}} {
  flow.variable @"__iree_flow___sm_node14__model.layer-1.kernel" opaque<"", "0xDEADBEEF"> : tensor<784x128xf32> attributes {sym_visibility = "private"}
  flow.variable @"__iree_flow___sm_node15__model.layer-1.bias" opaque<"", "0xDEADBEEF"> : tensor<128xf32> attributes {sym_visibility = "private"}
  flow.variable @"__iree_flow___sm_node20__model.layer-2.kernel" opaque<"", "0xDEADBEEF"> : tensor<128x10xf32> attributes {sym_visibility = "private"}
  flow.variable @"__iree_flow___sm_node21__model.layer-2.bias" dense<[-0.114143081, 0.0953421518, 4.84912744E-5, -0.0384164825, 0.0063888072, 0.218958765, 0.0256200824, 0.0551806651, -0.22108613, -0.0278935507]> : tensor<10xf32> attributes {sym_visibility = "private"}
  func @predict(%arg0: tensor<1x28x28x1xf32> {tf._user_specified_name = "x"}) -> tensor<1x10xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<1x28x28x1>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
    %0 = flow.variable.address @"__iree_flow___sm_node14__model.layer-1.kernel" : !iree.ptr<tensor<784x128xf32>>
    %1 = flow.variable.address @"__iree_flow___sm_node15__model.layer-1.bias" : !iree.ptr<tensor<128xf32>>
    %2 = flow.variable.address @"__iree_flow___sm_node20__model.layer-2.kernel" : !iree.ptr<tensor<128x10xf32>>
    %3 = flow.variable.address @"__iree_flow___sm_node21__model.layer-2.bias" : !iree.ptr<tensor<10xf32>>
    %4 = mhlo.constant dense<0xFF800000> : tensor<f32>
    %5 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %6 = flow.variable.load.indirect %3 : !iree.ptr<tensor<10xf32>> -> tensor<10xf32>
    %7 = flow.variable.load.indirect %2 : !iree.ptr<tensor<128x10xf32>> -> tensor<128x10xf32>
    %8 = flow.variable.load.indirect %1 : !iree.ptr<tensor<128xf32>> -> tensor<128xf32>
    %9 = flow.variable.load.indirect %0 : !iree.ptr<tensor<784x128xf32>> -> tensor<784x128xf32>
    %10 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
    %11 = ""(%10, %9) : (tensor<1x784xf32>, tensor<784x128xf32>) -> tensor<1x128xf32>
    %12 = "mhlo.broadcast_in_dim"(%8) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
    %13 = mhlo.add %11, %12 : tensor<1x128xf32>
    %14 = "mhlo.broadcast_in_dim"(%5) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x128xf32>
    %15 = mhlo.maximum %14, %13 : tensor<1x128xf32>
    %16 = ""(%15, %7) : (tensor<1x128xf32>, tensor<128x10xf32>) -> tensor<1x10xf32>
    %17 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>) -> tensor<1x10xf32>
    %18 = mhlo.add %16, %17 : tensor<1x10xf32>
    %19 = "mhlo.reduce"(%18, %4) ( {
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
      %26 = mhlo.maximum %arg1, %arg2 : tensor<f32>
      "mhlo.return"(%26) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
    %20 = "mhlo.broadcast_in_dim"(%19) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x10xf32>
    %21 = mhlo.subtract %18, %20 : tensor<1x10xf32>
    %22 = "mhlo.exponential"(%21) : (tensor<1x10xf32>) -> tensor<1x10xf32>
    %23 = "mhlo.reduce"(%22, %5) ( {
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
      %26 = mhlo.add %arg1, %arg2 : tensor<f32>
      "mhlo.return"(%26) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
    %24 = "mhlo.broadcast_in_dim"(%23) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x10xf32>
    %25 = mhlo.divide %22, %24 : tensor<1x10xf32>
    return %25 : tensor<1x10xf32>
Wrote MLIR to path 'C:\Users\Scott\saved_models\mnist.mlir'

In [9]:
#@title Compile the mhlo MLIR and prepare a context to execute it

# Compile the MLIR module into a VM module for execution
flatbuffer_blob = compiler_module.compile(target_backends=[backend_name])
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)

# Register the module with a runtime context
config = ireert.Config(driver_name)
ctx = ireert.SystemContext(config=config)

Created IREE driver vulkan: <pyiree.rt.binding.HalDriver object at 0x000001DC44C47370>
SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x000001DC44C47370>

In [10]:
#@title Execute the compiled module and compare the results with TensorFlow

# Invoke the 'predict' function with a single image as an argument
iree_prediction = ctx.modules.module.predict(sample_image_batch)

tf.print("IREE prediction ('%s' backend, '%s' driver):" % (backend_name, driver_name))
tf.print(tf.convert_to_tensor(iree_prediction[0]) * 100.0, summarize=100)
tf.print("TensorFlow prediction:")
tf.print(tf_prediction[0] * 100.0, summarize=100)

IREE prediction ('vulkan-spirv' backend, 'vulkan' driver):
[0.243133873 0.00337268622 95.5214233 0.92537272 2.25061631e-05 0.992090821 2.20864058 3.87712225e-06 0.105901062 4.44369434e-05]

TensorFlow prediction:
[0.243134052 0.00337268948 95.5214081 0.925373673 2.25061958e-05 0.992091119 2.20864391 3.87712953e-06 0.105901182 4.44369543e-05]