Copyright © 2019 The TensorFlow Authors.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

TFX Alternate Pipeline Architecture

This notebook demonstrates how to build and use TFX with a different pipeline architecture. Specifically, the pipeline in this example will branch following the feature engineering with Transform, so that two different Trainers each train and deploy two different model architectures from the same dataset. This illustrates one possible approach to doing A/B testing.

We will train an image classification model on the UC Merced Land Use Dataset of aerial pictures.

Setup

First, we install the necessary packages, download data, import modules and set up paths.

Install TFX and TensorFlow

Note

Because of some of the package updates, you must use the button at the bottom of this cell's output to restart the runtime. After the runtime restarts, rerun this cell.

Install TFX and TensorFlow


In [0]:
# Pinned installs: TF 2.0.0 (GPU build) + TFX 0.15.0rc0 (a release candidate).
# NOTE(review): tensorflow-datasets is unpinned — consider pinning for reproducibility.
!pip install -q -U \
  tensorflow-gpu==2.0.0 \
  tfx==0.15.0rc0 \
  tensorflow-datasets

Import packages

We import necessary packages, including standard TFX component classes.


In [0]:
import os
import tempfile
import urllib

import matplotlib.pyplot as plt

import tensorflow as tf
keras = tf.keras
K = keras.backend

import tensorflow_data_validation as tfdv
import tensorflow_datasets as tfds
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.eval_saved_model.export import build_parsing_eval_input_receiver_fn

import tfx
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.import_example_gen.component import ImportExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.utils.dsl_utils import external_input

from tfx.proto import evaluator_pb2
from tfx.proto import example_gen_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2

from tensorflow_metadata.proto.v0 import schema_pb2

In [0]:
# Confirm the TensorFlow version installed above (expected 2.0.0).
tf.__version__

In [0]:
# Confirm the TFX version installed above (expected 0.15.0rc0).
tfx.__version__

Download example data

We download the sample dataset for use in our TFX pipeline. We use TFDS to load the uc_merced dataset, for aerial image classification.


In [0]:
# Download and load the 'train' split of the uc_merced dataset via TFDS.
# as_supervised=True yields (image, label) pairs; with_info=True also returns
# the DatasetInfo metadata inspected in the next cells.
train_set, ds_info = tfds.load(name="uc_merced",
                               split="train",
                               as_supervised=True,
                               with_info=True)

In [0]:
# Inspect the dataset metadata (features, splits, citation, etc.).
ds_info

In [0]:
# Number of target classes, read from the dataset metadata.
n_classes = ds_info.features['label'].num_classes
n_classes

In [0]:
# Human-readable class names, used to title the preview images below.
class_names = ds_info.features['label'].names
class_names

In [0]:
# Preview a 10x5 grid of sample images with their class names as titles.
num_rows, num_cols = 10, 5
plt.figure(figsize=(4 * num_cols, 4 * num_rows))
for index, (image, label) in enumerate(train_set.take(num_rows * num_cols)):
  plt.subplot(num_rows, num_cols, index + 1)
  plt.imshow(image)
  plt.title(class_names[label])
  plt.axis('off')
plt.show()

Note that a few images are slightly smaller than 256x256:


In [0]:
# Report every image whose shape deviates from the expected 256x256 RGB.
expected_shape = (256, 256, 3)
for img, label in train_set:
    if img.shape != expected_shape:
        print(img.shape)

Running the pipeline interactively


In [0]:
# InteractiveContext lets us run TFX components one at a time in the notebook;
# by default it stores artifacts and ML-Metadata in an ephemeral directory.
context = InteractiveContext()

In [0]:
# Locate the TFRecord shard that TFDS wrote to disk and decode one example to
# verify the serialized format (the image is stored as PNG-encoded bytes).
HOME = os.path.expanduser('~')
# NOTE(review): the dataset version "0.0.1" is hardcoded — confirm it matches
# the version actually downloaded by tfds.load above.
examples_path = os.path.join(HOME, "tensorflow_datasets", "uc_merced", "0.0.1")
dataset = tf.data.TFRecordDataset(os.path.join(examples_path, "uc_merced-train.tfrecord-00000-of-00001"))
decoder = tfdv.TFExampleDecoder()
for tfrecord in dataset.take(1):
  example = decoder.decode(tfrecord.numpy())
  img = tf.io.decode_png(example['image'][0])

In [0]:
# Show the decoded example (a mapping from feature names to values).
example

In [0]:
# Display the decoded PNG to confirm the round trip worked.
plt.imshow(img)
plt.axis('off')
plt.show()

In [0]:
# Shape of the single decoded image tensor.
img.shape

In [0]:
# Point ImportExampleGen at the TFRecords TFDS already wrote to disk.
input_data = external_input(examples_path)

# Only a single 'train' split is present; match its shard files explicitly.
input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='uc_merced-train*')])
#Or equivalently:
#input_config = tfx.components.example_gen.utils.make_default_input_config(
#    split_pattern='uc_merced-train*')

# Brings the existing TFRecords into the pipeline as Examples artifacts.
example_gen = ImportExampleGen(input=input_data, input_config=input_config)

context.run(example_gen)

In [0]:
# Computes statistics over data for visualization and example validation.
# These statistics feed SchemaGen and ExampleValidator below.
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'])
context.run(statistics_gen)

In [0]:
# Generates schema based on statistics files. The schema is consumed by
# ExampleValidator, Transform and the Trainer modules below.
infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'])
context.run(infer_schema)

In [0]:
# Load the inferred schema proto from the SchemaGen output artifact.
# (Renamed the local from the misleading `train_uri`: this is the URI of the
# schema artifact, not of any training data.)
schema_uri = infer_schema.outputs['schema'].get()[0].uri
schema_filename = os.path.join(schema_uri, "schema.pbtxt")
schema = tfx.utils.io_utils.parse_pbtxt_file(file_name=schema_filename,
                                             message=schema_pb2.Schema())

In [0]:
# Visualize the inferred schema with TFDV.
tfdv.display_schema(schema)

In [0]:
# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=infer_schema.outputs['schema'])
context.run(validate_stats)

In [0]:
# Set up paths.
_transform_module_file = 'uc_merced_tranform.py'

In [0]:
%%writefile {_transform_module_file}

import tensorflow_transform as tft
import tensorflow as tf

# NOTE(review): LABEL_KEY is defined but never used in this module — confirm
# whether it can be removed or was intended for preprocessing_fn.
LABEL_KEY = 'label'

def transformed_name(name):
  """Return the post-transform name for raw feature `name`."""
  suffix = '_xf'
  return name + suffix

def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Every raw feature is densified via `_fill_in_missing` and re-keyed with the
  transform suffix; no other feature engineering is performed.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations.
  """
  return {
      transformed_name(key): _fill_in_missing(raw)
      for key, raw in inputs.items()
  }

def _fill_in_missing(x):
  """Replace missing values in a SparseTensor.

  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.

  Args:
    x: A `SparseTensor` of rank 2.  Its dense shape should have size at most 1
      in the second dimension.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  if x.dtype == tf.string:
    default_value = ''
  else:
    default_value = 0
  # Force the second dimension to exactly 1 so squeezing below is valid.
  narrowed = tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1])
  dense = tf.sparse.to_dense(narrowed, default_value)
  return tf.squeeze(dense, axis=1)

In [0]:
# Performs transformations and feature engineering in training and serving.
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    module_file=_transform_module_file)
context.run(transform)

In [0]:
# Importable module name and file path for the trainer module written below.
_trainer_module = 'uc_merced_trainer'
_trainer_module_file = _trainer_module + '.py'
# NOTE(review): _serving_model_dir is redefined with a fresh temp dir later
# (just before the Pusher cell), shadowing this value — confirm which is intended.
_serving_model_dir = os.path.join(tempfile.mkdtemp(),
                                  'serving_model/uc_merced_simple')

In [0]:
%%writefile {_trainer_module_file}

import tensorflow as tf
keras = tf.keras
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils

# Name of the label feature in the (raw) examples.
LABEL_KEY = 'label'
# Features present in the data but never used as model inputs.
DROP_FEATURES = ["filename"]
# Number of target classes in uc_merced.
NUM_CLASSES = 21

def transformed_name(name):
  """Name given to feature `name` after the Transform step."""
  xf_suffix = '_xf'
  return name + xf_suffix


# Tf.Transform considers these features as "raw"
def _get_raw_feature_spec(schema):
  """Convert a schema proto into a raw (pre-transform) feature spec dict."""
  spec_and_domains = schema_utils.schema_as_feature_spec(schema)
  return spec_and_domains.feature_spec


def _gzip_reader_fn(filenames):
  """Small utility returning a record reader that can read gzip'ed files."""
  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')


@tf.function
def decode_and_resize(image):
    # Decode one serialized PNG and resize to 256x256 (a few dataset images
    # are slightly smaller, as shown earlier in the notebook).
    return tf.image.resize(tf.io.decode_png(image), (256, 256))


@tf.function
def parse_png_images(png_images):
  # Decode a batch of serialized PNG strings into a float batch scaled to [0, 1].
  # NOTE(review): decoding is pinned to CPU — presumably because the decode ops
  # have no GPU kernel; confirm.
  with tf.device("/cpu:0"):
    flattened = tf.reshape(png_images, [-1])
    decoded = tf.map_fn(decode_and_resize, flattened, dtype=tf.float32)
    reshaped = tf.reshape(decoded, [-1, 256, 256, 3])
    return reshaped / 255.


def _build_estimator(config, num_filters=None):
  """Build an estimator for classifying uc_merced images.

  Args:
    config: tf.estimator.RunConfig defining the runtime environment for the
      estimator (including model_dir).
    num_filters: [int], number of filters for each successive Conv2D layer.
      If None, no convolutional layers are added and the decoded images feed
      straight into the softmax layer.

  Returns:
    The estimator that will be used for training and eval.
  """
  # BUG FIX: iterating the default None previously raised a TypeError; treat
  # None as "no convolutional layers".
  if num_filters is None:
    num_filters = []
  model = keras.models.Sequential()
  # The model consumes the transformed serialized-PNG feature ('image_xf').
  model.add(keras.layers.InputLayer(input_shape=[1], dtype="string", name="image_xf"))
  model.add(keras.layers.Lambda(parse_png_images))
  for filters in num_filters:
      model.add(keras.layers.Conv2D(filters=filters, kernel_size=3, activation="relu"))
      model.add(keras.layers.MaxPool2D())
  model.add(keras.layers.Flatten())
  model.add(keras.layers.Dense(NUM_CLASSES, activation="softmax"))
  model.compile(loss="sparse_categorical_crossentropy",
                optimizer="adam", metrics=["accuracy"])
  return tf.keras.estimator.model_to_estimator(
      keras_model=model,
      config=config,
      custom_objects={"parse_png_images": parse_png_images})


def _example_serving_receiver_fn(tf_transform_output, schema):
  """Build the serving inputs.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    Tensorflow graph which parses examples, applying tf-transform to them.
  """
  raw_feature_spec = _get_raw_feature_spec(schema)
  # The label is not an input at serving time.
  raw_feature_spec.pop(LABEL_KEY)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  # Apply the tf.Transform graph to the raw parsed features.
  transformed_features = tf_transform_output.transform_raw_features(
      serving_input_receiver.features)
  # The model only takes the transformed image as input, so drop the
  # transformed versions of the label and the unused features.
  for feature in DROP_FEATURES + [LABEL_KEY]:
    transformed_features.pop(transformed_name(feature))

  return tf.estimator.export.ServingInputReceiver(
      transformed_features, serving_input_receiver.receiver_tensors)


def _eval_input_receiver_fn(tf_transform_output, schema):
  """Build everything needed for the tf-model-analysis to run the model.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    EvalInputReceiver function, which contains:
      - Tensorflow graph which parses raw untransformed features, applies the
        tf-transform preprocessing operators.
      - Set of raw, untransformed features.
      - Label against which predictions will be compared.
  """
  # Notice that the inputs are raw features, not transformed features here.
  raw_feature_spec = _get_raw_feature_spec(schema)

  # Placeholder for the batch of serialized tf.Examples that TFMA will feed.
  serialized_tf_example = tf.compat.v1.placeholder(
      dtype=tf.string, shape=[None], name='input_example_tensor')

  # Add a parse_example operator to the tensorflow graph, which will parse
  # raw, untransformed, tf examples.
  features = tf.io.parse_example(serialized_tf_example, raw_feature_spec)

  # Now that we have our raw examples, process them through the tf-transform
  # function computed during the preprocessing step.
  transformed_features = tf_transform_output.transform_raw_features(
      features)

  # The key name MUST be 'examples'.
  receiver_tensors = {'examples': serialized_tf_example}

  # NOTE: Model is driven by transformed features (since training works on the
  # materialized output of TFT, but slicing will happen on raw features).
  features.update(transformed_features)
  # Remove the label and dropped features (under both raw and transformed
  # names) from the dict handed to TFMA.
  for feature in DROP_FEATURES + [LABEL_KEY]:
    if feature in features:
        features.pop(feature)
    if transformed_name(feature) in features:
        features.pop(transformed_name(feature))
  # NOTE(review): the raw serialized 'image' bytes are removed as well —
  # presumably to avoid carrying large strings through TFMA; confirm.
  features.pop('image')
  return tfma.export.EvalInputReceiver(
      features=features,
      receiver_tensors=receiver_tensors,
      labels=transformed_features[transformed_name(LABEL_KEY)])


def _input_fn(filenames, tf_transform_output, batch_size=200):
  """Generates features and labels for training or evaluation.

  Args:
    filenames: [str] list of gzipped TFRecord files to read data from
      (previously mis-documented as CSV files).
    tf_transform_output: A TFTransformOutput.
    batch_size: int First dimension size of the Tensors returned by input_fn

  Returns:
    A (features, indices) tuple where features is a dictionary of
      Tensors, and indices is a single Tensor of label indices.
  """
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())
  dataset = tf.data.experimental.make_batched_features_dataset(
      filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn)

  # NOTE(review): make_one_shot_iterator() is the TF1-style API — confirm it is
  # still available on this dataset object under TF 2.0's estimator graph mode.
  transformed_features = dataset.make_one_shot_iterator().get_next()

  # Dropped features (e.g. filename) are not model inputs.
  for feature in DROP_FEATURES:
    transformed_features.pop(transformed_name(feature))

  # Pop the label out of the feature dict and return it separately.
  return transformed_features, transformed_features.pop(
      transformed_name(LABEL_KEY))


# TFX will call this function
# TFX will call this function
def trainer_fn(hparams, schema):
  """Build the estimator using the high level API.

  Expects `hparams.first_cnn_filters` and `hparams.num_cnn_layers` to have been
  set by the caller (see trainer_fn1 / trainer_fn2 below).

  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.
  Returns:
    A dict of the following:
      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(hparams.transform_output)

  train_input_fn = lambda: _input_fn(
      hparams.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(
      hparams.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn,
      max_steps=hparams.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(
      tf_transform_output, schema)

  # Export the final model for serving under the 'uc-merced' name.
  exporter = tf.estimator.FinalExporter('uc-merced', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=hparams.eval_steps,
      exporters=[exporter],
      name='uc-merced-eval')

  # NOTE(review): 999 is an arbitrary checkpoint interval; only the latest
  # checkpoint is kept.
  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=hparams.serving_model_dir)

  # Filter counts double with each conv layer, starting from first_cnn_filters.
  num_filters = [hparams.first_cnn_filters]
  for layer_index in range(1, hparams.num_cnn_layers):
    num_filters.append(num_filters[-1] * 2)

  estimator = _build_estimator(
      config=run_config,
      num_filters=num_filters)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }

def trainer_fn1(hparams, schema):
  """Trainer entry point for architecture 1: 4 conv layers, 32 filters first."""
  hparams.num_cnn_layers = 4
  hparams.first_cnn_filters = 32
  return trainer_fn(hparams, schema)

def trainer_fn2(hparams, schema):
  """Trainer entry point for architecture 2: 5 conv layers, 16 filters first."""
  hparams.num_cnn_layers = 5
  hparams.first_cnn_filters = 16
  return trainer_fn(hparams, schema)

In [0]:
# Uses user-provided Python function that implements a model using TensorFlow's
# Estimators API. The dotted name selects trainer_fn1 (architecture 1) from the
# trainer module written above.
trainer = Trainer(
    trainer_fn="{}.trainer_fn1".format(_trainer_module),
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=200),
    eval_args=trainer_pb2.EvalArgs(num_steps=100))
context.run(trainer)

In [0]:
# Uses TFMA to compute evaluation statistics over features of a model,
# sliced by the transformed label (one slice per class).
model_analyzer = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
        evaluator_pb2.SingleSlicingSpec(
            column_for_slicing=['label_xf'])
    ]))
context.run(model_analyzer)

In [0]:
# Load and display the TFMA evaluation results from the Evaluator output artifact.
evaluation_uri = model_analyzer.outputs['output'].get()[0].uri
eval_result = tfma.load_eval_result(evaluation_uri)
eval_result

In [0]:
# Performs quality validation of a candidate model (compared to a baseline).
model_validator = ModelValidator(
    examples=example_gen.outputs['examples'], model=trainer.outputs['model'])
context.run(model_validator)

In [0]:
# Inspect the blessing artifact directory written by ModelValidator.
blessing_uri = model_validator.outputs['blessing'].get()[0].uri
!ls -l {blessing_uri}

In [0]:
# Setup serving path
# NOTE(review): this redefines _serving_model_dir with a fresh temp dir,
# shadowing the value created before the trainer module cell — confirm intended.
_serving_model_dir = os.path.join(tempfile.mkdtemp(),
                                  'serving_model/uc_merced_simple')

In [0]:
# Checks whether the model passed the validation steps and pushes the model
# to a file destination if check passed.
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=model_validator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher)

Create the pipeline using Beam orchestration


In [0]:
# Names and on-disk locations for the Beam-orchestrated pipeline run.
_pipeline_name = 'uc_merced_beam'

_pipeline_root = tempfile.mkdtemp(prefix='tfx-pipelines')
_pipeline_root = os.path.join(_pipeline_root, 'pipelines', _pipeline_name)

# Sqlite ML-metadata db path.
_metadata_root = tempfile.mkdtemp(prefix='tfx-metadata')
_metadata_path = os.path.join(_metadata_root, 'metadata.db')

In [0]:
def _create_pipeline(pipeline_name, pipeline_root, data_root,
                     transform_module_file, trainer_module_file,
                     serving_model_dir, metadata_path):
  """Implements the UC Merced classification pipeline with TFX.

  Builds one shared ingestion/validation/transform chain, then two parallel
  Trainer → Evaluator → ModelValidator → Pusher branches (one per model
  architecture), illustrating an A/B-style pipeline.

  Args:
    pipeline_name: name of the pipeline.
    pipeline_root: root directory for pipeline artifacts.
    data_root: directory containing the TFRecord files to ingest.
    transform_module_file: path to the module providing preprocessing_fn.
    trainer_module_file: path to the module providing trainer_fn1/trainer_fn2.
    serving_model_dir: base directory where blessed models are pushed.
    metadata_path: path of the sqlite ML-Metadata database.

  Returns:
    A TFX pipeline.Pipeline ready to be run by an orchestrator.
  """
  # BUG FIX: derive the importable module name from the trainer_module_file
  # parameter; previously the parameter was accepted but ignored and the
  # notebook-global _trainer_module was used instead.
  trainer_module = os.path.splitext(os.path.basename(trainer_module_file))[0]

  input_data = external_input(data_root)

  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='train', pattern='uc_merced-train*')])
  #Or equivalently:
  #input_config = tfx.components.example_gen.utils.make_default_input_config(
  #    split_pattern='uc_merced-train*')

  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = ImportExampleGen(input=input_data, input_config=input_config)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  # Generates schema based on statistics files.
  infer_schema = SchemaGen(
      statistics=statistics_gen.outputs['statistics'])

  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'])

  # Performs transformations and feature engineering in training and serving.
  transform = Transform(
      examples=example_gen.outputs['examples'],
      schema=infer_schema.outputs['schema'],
      module_file=transform_module_file)

  components = [example_gen, statistics_gen, infer_schema,
                validate_stats, transform]

  # One branch per model architecture (trainer_fn1 / trainer_fn2).
  for index in (1, 2):
    # Uses user-provided Python function that implements a model using
    # TensorFlow's Estimators API.
    trainer = Trainer(
        trainer_fn='{}.trainer_fn{}'.format(trainer_module, index),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=infer_schema.outputs['schema'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=200),
        eval_args=trainer_pb2.EvalArgs(num_steps=100),
        instance_name='Trainer{}'.format(index))

    # Uses TFMA to compute evaluation statistics over features of a model.
    model_analyzer = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
            evaluator_pb2.SingleSlicingSpec(
                column_for_slicing=['label_xf'])
        ]),
        instance_name="Evaluator{}".format(index))

    # Performs quality validation of a candidate model (compared to a baseline).
    model_validator = ModelValidator(
        examples=example_gen.outputs['examples'], model=trainer.outputs['model'],
        instance_name="ModelValidator{}".format(index))

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(
        model=trainer.outputs['model'],
        model_blessing=model_validator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=serving_model_dir)),
        instance_name='Pusher{}'.format(index))

    components += [trainer, model_analyzer, model_validator, pusher]

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components,
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      additional_pipeline_args={},
  )

In [0]:
# Assemble the Beam pipeline from the paths and module files defined above.
uc_merced_pipeline = _create_pipeline(
        pipeline_name=_pipeline_name,
        pipeline_root=_pipeline_root,
        data_root=examples_path,
        transform_module_file=_transform_module_file,
        trainer_module_file=_trainer_module_file,
        serving_model_dir=_serving_model_dir,
        metadata_path=_metadata_path)

In [0]:
# Execute the whole pipeline locally with the Beam orchestrator.
BeamDagRunner().run(uc_merced_pipeline)

In [0]: