Defines a simple TF module, saves it and loads it in IREE.

Start kernel:

  • Install a TensorFlow2 nightly pip (or bring your own)
  • Enable IREE/TF integration by adding to your user.bazelrc: build --define=iree_tensorflow=true
  • Optional: Prime the build: bazel build bindings/python/pyiree
  • Start colab by running python colab/start_colab_kernel.py (see that file for initial setup instructions)

TODO:

  • This is just using low-level binding classes. Change to high level API.
  • Plumb through the ability to run TF compiler lowering passes and import directly into IREE

In [0]:
import os
import tensorflow as tf
from pyiree.tf import compiler as ireec

# Directory (under the user's home) where SavedModels are written before
# being imported into IREE. Created eagerly so later saves cannot fail.
_home = os.environ["HOME"]
SAVE_PATH = os.path.join(_home, "saved_models")
os.makedirs(SAVE_PATH, exist_ok=True)

In [15]:
class MyModule(tf.Module):
  """Minimal tf.Module exporting one variable `v` and one function `add`.

  `add(a, b)` computes `tanh(v * a + b)` elementwise over rank-1 float32
  tensors of shape [4]; `v` (shape [1]) broadcasts against them.
  """

  def __init__(self):
    # tf.Module subclasses must call super().__init__() so that the
    # module's name scope and variable tracking are initialized before
    # trackable attributes (like self.v below) are assigned.
    super().__init__()
    self.v = tf.Variable([4], dtype=tf.float32)

  @tf.function(
      input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)]
  )
  def add(self, a, b):
    """Returns tanh(v * a + b) for shape-[4] float32 inputs a and b."""
    return tf.tanh(self.v * a + b)

# Export the module as a SavedModel (with debug info retained) and then
# immediately re-import it into the IREE compiler.
options = tf.saved_model.SaveOptions(save_debug_info=True)
my_mod = MyModule()
saved_model_path = os.path.join(SAVE_PATH, "simple.sm")
tf.saved_model.save(my_mod, saved_model_path, options=options)

# Load with an empty pass pipeline so the raw TF import can be inspected.
input_module = ireec.tf_load_saved_model(saved_model_path, pass_pipeline=[])
print('LOADED ASM:', input_module.to_asm())

# Lower the imported TF executor graph into the plain TF dialect and
# clean it up (this is the canonicalization step for the TF import).
tf_lowering_passes = [
    "tf-executor-graph-pruning",
    "tf-standard-pipeline",
    "canonicalize",
]
input_module.run_pass_pipeline(tf_lowering_passes)
print("LOWERED TF ASM:", input_module.to_asm())

# Legalize the remaining TF ops to the high-level XLA/MHLO dialect;
# partial conversion leaves unconvertible ops (e.g. resource reads) as-is.
xla_legalization_passes = [
    "xla-legalize-tf{allow-partial-conversion=true}",
]
input_module.run_pass_pipeline(xla_legalization_passes)
print("XLA ASM:", input_module.to_asm())


INFO:tensorflow:Assets written to: /usr/local/google/home/scotttodd/saved_models/simple.sm/assets
LOADED ASM: 

module attributes {tf_saved_model.semantics} {
  "tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node1__v", tf_saved_model.exported_names = ["v"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()
  func @__inference_add_10820(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = ["tfshape$dim { size: 4 }", "tfshape$dim { size: 4 }", "tfshape$unknown_rank: true"], tf.signature.is_stateful, tf_saved_model.exported_names = ["add"]} {
    %0 = tf_executor.graph {
      %outputs, %control = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {_output_shapes = ["tfshape$dim { size: 1 }"], device = "", dtype = f32} : (tensor<*x!tf.resource>) -> tensor<1xf32>
      %outputs_0, %control_1 = tf_executor.island wraps "tf.Mul"(%outputs, %arg0) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>
      %outputs_2, %control_3 = tf_executor.island wraps "tf.AddV2"(%outputs_0, %arg1) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      %outputs_4, %control_5 = tf_executor.island wraps "tf.Tanh"(%outputs_2) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<4xf32>) -> tensor<4xf32>
      %outputs_6, %control_7 = tf_executor.island(%control) wraps "tf.Identity"(%outputs_4) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<4xf32>) -> tensor<4xf32>
      tf_executor.fetch %outputs_6, %control : tensor<4xf32>, !tf_executor.control
    }
    return %0 : tensor<4xf32>
  }
}

LOWERED TF ASM: 

module attributes {tf_saved_model.semantics} {
  "tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node1__v", tf_saved_model.exported_names = ["v"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()
  func @__inference_add_10820(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = ["tfshape$dim { size: 4 }", "tfshape$dim { size: 4 }", "tfshape$unknown_rank: true"], tf.signature.is_stateful, tf_saved_model.exported_names = ["add"]} {
    %0 = "tf.ReadVariableOp"(%arg2) {_output_shapes = ["tfshape$dim { size: 1 }"], device = "", dtype = f32} : (tensor<*x!tf.resource>) -> tensor<1xf32>
    %1 = "tf.Mul"(%0, %arg0) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>
    %2 = "tf.AddV2"(%1, %arg1) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    %3 = "tf.Tanh"(%2) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<4xf32>) -> tensor<4xf32>
    %4 = "tf.Identity"(%3) {T = f32, _output_shapes = ["tfshape$dim { size: 4 }"], device = ""} : (tensor<4xf32>) -> tensor<4xf32>
    return %4 : tensor<4xf32>
  }
}

XLA ASM: 

module attributes {tf_saved_model.semantics} {
  "tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node1__v", tf_saved_model.exported_names = ["v"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()
  func @__inference_add_10820(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = ["tfshape$dim { size: 4 }", "tfshape$dim { size: 4 }", "tfshape$unknown_rank: true"], tf.signature.is_stateful, tf_saved_model.exported_names = ["add"]} {
    %0 = "tf.ReadVariableOp"(%arg2) {_output_shapes = ["tfshape$dim { size: 1 }"], device = "", dtype = f32} : (tensor<*x!tf.resource>) -> tensor<1xf32>
    %1 = "mhlo.multiply"(%0, %arg0) : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>
    %2 = mhlo.add %1, %arg1 : tensor<4xf32>
    %3 = "mhlo.tanh"(%2) : (tensor<4xf32>) -> tensor<4xf32>
    return %3 : tensor<4xf32>
  }
}