Copyright © 2019 The TensorFlow Authors.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

TFX — Running a simple pipeline manually in a Colab Notebook


This notebook demonstrates how to use Jupyter/Colab notebooks for iterative TFX development. Here, we build a pipeline for the Online News Popularity dataset step by step in an interactive notebook.

Working in an interactive notebook is a useful way to become familiar with the structure of a TFX pipeline, and it makes a lightweight development environment for your own pipelines. Be aware, however, that interactive notebooks differ from production deployments in how they are orchestrated and how they access metadata artifacts.

Orchestration

In a production deployment of TFX you will use an orchestrator such as Apache Airflow, Kubeflow, or Apache Beam. In an interactive notebook the notebook itself is the orchestrator, running each TFX component as you execute the notebook cells.

Metadata

In a production deployment of TFX you will access metadata through the ML Metadata (MLMD) API. MLMD stores metadata properties in a database such as MySQL, and stores the metadata payloads in a persistent store such as on your filesystem. In an interactive notebook, both properties and payloads are stored in the /tmp directory on the Jupyter notebook or Colab server.

Setup

First, we install the necessary packages, download data, import modules and set up paths.

Install TFX and TensorFlow

Note

Because of package updates, you must use the button at the bottom of this cell's output to restart the runtime. After the restart, rerun this cell.


In [0]:
!pip install -q -U \
  tensorflow==2.0.0 \
  tfx==0.15.0rc0

Import packages

We import necessary packages, including standard TFX component classes.


In [0]:
import base64
import csv
import json
import os
import requests
import tempfile
import urllib.request
import pprint
pp = pprint.PrettyPrinter()

import tensorflow as tf

import tfx
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import evaluator_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.utils.dsl_utils import external_input

from tensorflow_metadata.proto.v0 import anomalies_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2

import tensorflow_transform as tft
from tensorflow_transform import coders as tft_coders
from tensorflow_transform.tf_metadata import dataset_schema
from tensorflow_transform.tf_metadata import schema_utils

import tensorflow_model_analysis as tfma
import tensorflow_data_validation as tfdv

Check the versions


In [0]:
print('TensorFlow version: {}'.format(tf.__version__))
print('TFX version: {}'.format(tfx.__version__))

Download example data

We download the sample dataset for use in our TFX pipeline. We're working with a variant of the Online News Popularity dataset, which summarizes a heterogeneous set of features about articles published by Mashable over a period of two years. The goal is to predict how popular an article will be on social networks. In the original dataset the objective was to predict the number of times each article would be shared; in this variant, the goal is to predict the article's popularity percentile. For example, if the model predicts a score of 0.7, it expects the article to get more shares than 70% of all articles.


In [0]:
# Download the example data.
DATA_PATH = 'https://raw.githubusercontent.com/ageron/open-datasets/master/' \
   'online_news_popularity_for_course/online_news_popularity_for_course.csv'
_data_root = tempfile.mkdtemp(prefix='tfx-data')
_data_filepath = os.path.join(_data_root, "data.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)

Take a quick look at the CSV file.


In [0]:
!head {_data_filepath}
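
For a more structured look, you can also load the CSV into a pandas DataFrame (a convenience only, not part of the pipeline; pandas comes preinstalled in Colab):


In [0]:
import pandas as pd

# Quick preview of the raw data; the pipeline itself reads the CSV directly.
pd.read_csv(_data_filepath).head()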

Create the InteractiveContext

An interactive context is used to provide global context when running a TFX pipeline in a notebook without using a runner or orchestrator such as Apache Airflow or Kubeflow. This style of development is only useful when developing the code for a pipeline, and cannot currently be used to deploy a working pipeline to production.


In [0]:
# Here, we create an InteractiveContext using default parameters. This will
# use a temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, the optional properties
# `pipeline_root` and `metadata_connection_config` may be passed to
# InteractiveContext.
context = InteractiveContext()
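
If you want the pipeline root and metadata to survive beyond the temporary defaults, you can pass the optional parameters mentioned above. Here is a sketch using a SQLite-backed ML Metadata store; the paths are hypothetical, and the rest of this lab keeps using the default context created above:


In [0]:
# Optional sketch: an InteractiveContext with an explicit pipeline root and a
# SQLite-backed ML Metadata store. The paths below are hypothetical.
from ml_metadata.proto import metadata_store_pb2

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = '/content/tfx_metadata.sqlite'
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)

custom_context = InteractiveContext(
    pipeline_root='/content/tfx_pipeline_root',
    metadata_connection_config=connection_config)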

Run TFX Components Interactively


In the cells that follow you will construct TFX components and run each one interactively within the InteractiveContext to obtain ExecutionResult objects. This mirrors the process of an orchestrator running components in a TFX DAG based on when the dependencies for each component are met.

The ExampleGen Component

In any ML development process, the first step is to ingest the training and test datasets. The ExampleGen component brings data into the TFX pipeline.

Let's create an ExampleGen component and run it.

Exercise 1 — Creating and Running Your First Component

  1. Use the external_input() function to create an input Channel for the ExampleGen component.
  2. Create an instance of the CsvExampleGen class, passing it the input channel you just built.
  3. Call the InteractiveContext's run() method, passing it the ExampleGen instance.

Hint: if you're wondering about some TFX function or class, you can check out TFX's API documentation, or just type its name followed by a question mark (e.g., CsvExampleGen?) and run the cell.


In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
# Use the packaged CSV input data.
input_data = external_input(_data_root)

# Brings data into the pipeline or otherwise joins/converts training data.
example_gen = CsvExampleGen(input=input_data)
context.run(example_gen)

The component's outputs include 2 artifacts: the training examples and the eval examples (by default, split 2/3 training, 1/3 eval):


In [0]:
for artifact in example_gen.outputs['examples'].get():
  print(artifact.split, artifact.uri)
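
The default 2:1 split can be customized when constructing the component. Here is a sketch, assuming the output_config argument and the example_gen_pb2 protos available in this version of TFX; the cell only constructs the component, and the rest of the lab keeps the default split:


In [0]:
from tfx.proto import example_gen_pb2

# Hypothetical 3:1 train/eval split using hash buckets; not run in this lab.
output_config = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=3),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]))
custom_example_gen = CsvExampleGen(input=input_data,
                                   output_config=output_config)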

Take a peek at the output training examples to see what they look like.

  1. Get the URI of the output artifact representing the training examples, which is a directory
  2. Get the list of files in this directory (all compressed TFRecord files), and create a TFRecordDataset to read these files
  3. Take the first record and decode it using a TFExampleDecoder to check the results

In [0]:
train_uri = example_gen.outputs['examples'].get()[0].uri
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
decoder = tfdv.TFExampleDecoder()
for tfrecord in dataset.take(1):
  serialized_example = tfrecord.numpy()
  example = decoder.decode(serialized_example)
  pp.pprint(example)

The StatisticsGen Component

The StatisticsGen component computes descriptive statistics for your dataset. The statistics that it generates can be visualized for review, and are used for example validation and to infer a schema.

Create a StatisticsGen component and run it.

Exercise 2 — Computing Statistics

  1. Create an instance of the StatisticsGen class, passing it the examples channel that was output by the CsvExampleGen component.
  2. Use the InteractiveContext to run this component.

Hint: example_gen.outputs will return a dictionary containing all the outputs of the example_gen component. Examine its keys: you will find that it only has one "examples" key. The corresponding value is the Channel object that you need to pass to the constructor of the StatisticsGen class.


In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'])
context.run(statistics_gen)

Again, let's take a peek at the output training artifact. Note that this time it is a TFRecord file containing a single record with a serialized DatasetFeatureStatisticsList protobuf:


In [0]:
train_uri = statistics_gen.outputs['statistics'].get()[0].uri
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]
dataset = tf.data.TFRecordDataset(tfrecord_filenames)
for tfrecord in dataset.take(1):
  serialized_example = tfrecord.numpy()
  stats = statistics_pb2.DatasetFeatureStatisticsList()
  stats.ParseFromString(serialized_example)

The stats can be visualized using the tfdv.visualize_statistics() function (we will look at this in more detail in a subsequent lab).


In [0]:
tfdv.visualize_statistics(stats)

The SchemaGen Component

The SchemaGen component generates a schema for your data based on the statistics from StatisticsGen. It tries to infer the data types of each of your features, and the ranges of legal values for categorical features.

Create a SchemaGen component and run it.

Exercise 3 — Inferring the Schema

  1. Create an instance of the SchemaGen class, passing it the statistics channel that was output by the StatisticsGen component.
  2. Use the InteractiveContext to run this component.

In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:

In [0]:
# Generates schema based on statistics files.
infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'])
context.run(infer_schema)

The generated artifact is just a schema.pbtxt containing a text representation of a schema_pb2.Schema protobuf:


In [0]:
train_uri = infer_schema.outputs['schema'].get()[0].uri
schema_filename = os.path.join(train_uri, "schema.pbtxt")
schema = tfx.utils.io_utils.parse_pbtxt_file(file_name=schema_filename,
                                             message=schema_pb2.Schema())

It can be visualized using tfdv.display_schema() (we will look at this in more detail in a subsequent lab):


In [0]:
tfdv.display_schema(schema)

The ExampleValidator Component

The ExampleValidator performs anomaly detection, based on the statistics from StatisticsGen and the schema from SchemaGen. It looks for problems such as missing values, values of the wrong type, or categorical values outside of the domain of acceptable values.

Create an ExampleValidator component and run it.

Exercise 4 — Validating the Examples

  1. Create an instance of the ExampleValidator class. This time you need to pass the constructor two channels: the statistics channel from the StatisticsGen component, and the schema channel from the SchemaGen component.
  2. Use the InteractiveContext to run this component.

In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=infer_schema.outputs['schema'])
context.run(validate_stats)

The output artifact of the ExampleValidator is an anomalies.pbtxt file describing an anomalies_pb2.Anomalies protobuf:


In [0]:
train_uri = validate_stats.outputs['anomalies'].get()[0].uri
anomalies_filename = os.path.join(train_uri, "anomalies.pbtxt")
anomalies = tfx.utils.io_utils.parse_pbtxt_file(
    file_name=anomalies_filename,
    message=anomalies_pb2.Anomalies())

This can be visualized using the tfdv.display_anomalies() function (we will look at this in more detail in a subsequent lab). Did it find any anomalies?


In [0]:
tfdv.display_anomalies(anomalies)

The Transform Component

The Transform component performs data transformations and feature engineering. The results include an input TensorFlow graph which is used during both training and serving to preprocess the data before training or inference. This graph becomes part of the SavedModel that is the result of model training. Since the same input graph is used for both training and serving, the preprocessing will always be the same, and only needs to be written once.

The Transform component requires more code than many other components because of the arbitrary complexity of the feature engineering that you may need for the data and/or model that you're working with. It requires code files to be available which define the processing needed.

First, define some constants and functions shared by the Transform and Trainer components. Since you are working in a notebook, save them in a Python module on disk using the %%writefile magic command.


In [0]:
_constants_module_file = 'online_news_constants.py'

In [0]:
%%writefile {_constants_module_file}

DENSE_FLOAT_FEATURE_KEYS = [
    "timedelta", "n_tokens_title", "n_tokens_content",
    "n_unique_tokens", "n_non_stop_words", "n_non_stop_unique_tokens",
    "n_hrefs", "n_self_hrefs", "n_imgs", "n_videos", "average_token_length",
    "n_keywords", "kw_min_min", "kw_max_min", "kw_avg_min", "kw_min_max",
    "kw_max_max", "kw_avg_max", "kw_min_avg", "kw_max_avg", "kw_avg_avg",
    "self_reference_min_shares", "self_reference_max_shares",
    "self_reference_avg_shares", "is_weekend", "global_subjectivity",
    "global_sentiment_polarity", "global_rate_positive_words",
    "global_rate_negative_words", "rate_positive_words", "rate_negative_words",
    "avg_positive_polarity", "min_positive_polarity", "max_positive_polarity",
    "avg_negative_polarity", "min_negative_polarity", "max_negative_polarity",
    "title_subjectivity", "title_sentiment_polarity", "abs_title_subjectivity",
    "abs_title_sentiment_polarity"]

VOCAB_FEATURE_KEYS = ["data_channel"]

BUCKET_FEATURE_KEYS = ["LDA_00", "LDA_01", "LDA_02", "LDA_03", "LDA_04"]

CATEGORICAL_FEATURE_KEYS = ["weekday"]

# Categorical features are assumed to each have a maximum value in the dataset.
MAX_CATEGORICAL_FEATURE_VALUES = [6]

#UNUSED: date, slug

LABEL_KEY = "n_shares_percentile"
VOCAB_SIZE = 10
OOV_SIZE = 5
FEATURE_BUCKET_COUNT = 10

def transformed_name(key):
  return key + '_xf'

Now let's define a module containing the preprocessing_fn() function that we will pass to the Transform component:


In [0]:
_transform_module_file = 'online_news_transform.py'

In [0]:
%%writefile {_transform_module_file}

import tensorflow as tf

import tensorflow_transform as tft
from online_news_constants import *

def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}
  for key in DENSE_FLOAT_FEATURE_KEYS:
    # Preserve this feature as a dense float, setting nan's to the mean.
    outputs[transformed_name(key)] = tft.scale_to_z_score(
        _fill_in_missing(inputs[key]))

  for key in VOCAB_FEATURE_KEYS:
    # Build a vocabulary for this feature.
    outputs[transformed_name(key)] = tft.compute_and_apply_vocabulary(
        _fill_in_missing(inputs[key]),
        top_k=VOCAB_SIZE,
        num_oov_buckets=OOV_SIZE)

  for key in BUCKET_FEATURE_KEYS:
    outputs[transformed_name(key)] = tft.bucketize(
        _fill_in_missing(inputs[key]), FEATURE_BUCKET_COUNT,
        always_return_num_quantiles=False)

  for key in CATEGORICAL_FEATURE_KEYS:
    outputs[transformed_name(key)] = _fill_in_missing(inputs[key])

  # How popular is this article?
  outputs[transformed_name(LABEL_KEY)] = _fill_in_missing(inputs[LABEL_KEY])

  return outputs

def _fill_in_missing(x):
  """Replace missing values in a SparseTensor.

  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.

  Args:
    x: A `SparseTensor` of rank 2.  Its dense shape should have size at most 1
      in the second dimension.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  default_value = '' if x.dtype == tf.string else 0
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)

Create and run the Transform component, referring to the files that were created above.

Exercise 5 — Transforming the Input Features

  1. Create an instance of the Transform class. You will need to pass it the examples channel, the schema channel and the name of the transform module file.
  2. Use the InteractiveContext to run this component.

Hint: run Transform? to find out the argument names you need to set.


In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
# Performs transformations and feature engineering in training and serving.
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    module_file=_transform_module_file)
context.run(transform)

The Transform component has 2 types of outputs:

  • transform_graph is the graph that can perform the preprocessing operations (this graph will be included in the serving and evaluation models).
  • transformed_examples represents the preprocessed training and evaluation data.

In [0]:
transform.outputs

Take a peek at the transform_graph artifact: it points to a directory containing 3 subdirectories:


In [0]:
train_uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(train_uri)

The transform_fn subdirectory contains the actual preprocessing graph. The metadata subdirectory contains the schema of the original data. The transformed_metadata subdirectory contains the schema of the preprocessed data.
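
You can also wrap this directory in a tft.TFTransformOutput to inspect it programmatically. A small sketch, assuming the directory layout described above:


In [0]:
# train_uri still points at the transform_graph artifact directory.
tf_transform_output = tft.TFTransformOutput(train_uri)
pp.pprint(tf_transform_output.transformed_feature_spec())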

Take a look at some of the transformed examples and check that they are indeed processed as intended.


In [0]:
train_uri = transform.outputs['transformed_examples'].get()[1].uri
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
decoder = tfdv.TFExampleDecoder()
for tfrecord in dataset.take(3):
  serialized_example = tfrecord.numpy()
  example = decoder.decode(serialized_example)
  pp.pprint(example)

The Trainer Component

The Trainer component trains models using TensorFlow.

Create a Python module containing a trainer_fn function, which must return an estimator. If you prefer creating a Keras model, you can do so and then convert it to an estimator using tf.keras.estimator.model_to_estimator().
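
For example, a compiled Keras model could be wrapped like this (an illustrative sketch only; the architecture below is hypothetical, and this lab's trainer builds an estimator directly):


In [0]:
# Hypothetical sketch: wrap a compiled Keras model as an estimator.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(41,)),
    tf.keras.layers.Dense(1)
])
keras_model.compile(optimizer='adam', loss='mse')
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=keras_model)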


In [0]:
# Setup paths.
_trainer_module_file = 'online_news_trainer.py'

In [0]:
%%writefile {_trainer_module_file}

import tensorflow as tf

import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils

from online_news_constants import *


def transformed_names(keys):
  return [transformed_name(key) for key in keys]


# Tf.Transform considers these features as "raw"
def _get_raw_feature_spec(schema):
  return schema_utils.schema_as_feature_spec(schema).feature_spec


def _gzip_reader_fn(filenames):
  """Small utility returning a record reader that can read gzip'ed files."""
  return tf.data.TFRecordDataset(
      filenames,
      compression_type='GZIP')


def _build_estimator(config, hidden_units=None, warm_start_from=None):
  """Build an estimator for predicting the popularity of online news articles

  Args:
    config: tf.estimator.RunConfig defining the runtime environment for the
      estimator (including model_dir).
    hidden_units: [int], the layer sizes of the DNN (input layer first)
    warm_start_from: Optional directory to warm start from.

  Returns:
    The estimator that will be used for training and eval.
  """
  real_valued_columns = [
      tf.feature_column.numeric_column(key, shape=())
      for key in transformed_names(DENSE_FLOAT_FEATURE_KEYS)
  ]
  categorical_columns = [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=VOCAB_SIZE + OOV_SIZE, default_value=0)
      for key in transformed_names(VOCAB_FEATURE_KEYS)
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=FEATURE_BUCKET_COUNT, default_value=0)
      for key in transformed_names(BUCKET_FEATURE_KEYS)
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(
          key,
          num_buckets=num_buckets,
          default_value=0) for key, num_buckets in zip(
              transformed_names(CATEGORICAL_FEATURE_KEYS),
              MAX_CATEGORICAL_FEATURE_VALUES)
  ]
  return tf.estimator.DNNLinearCombinedRegressor(
      config=config,
      linear_feature_columns=categorical_columns,
      dnn_feature_columns=real_valued_columns,
      dnn_hidden_units=hidden_units or [100, 70, 50, 25],
      warm_start_from=warm_start_from)


def _example_serving_receiver_fn(tf_transform_output, schema):
  """Build the serving in inputs.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    Tensorflow graph which parses examples, applying tf-transform to them.
  """
  raw_feature_spec = _get_raw_feature_spec(schema)
  raw_feature_spec.pop(LABEL_KEY)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  transformed_features = tf_transform_output.transform_raw_features(
      serving_input_receiver.features)

  return tf.estimator.export.ServingInputReceiver(
      transformed_features, serving_input_receiver.receiver_tensors)


def _eval_input_receiver_fn(tf_transform_output, schema):
  """Build everything needed for the tf-model-analysis to run the model.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    EvalInputReceiver function, which contains:
      - Tensorflow graph which parses raw untransformed features, applies the
        tf-transform preprocessing operators.
      - Set of raw, untransformed features.
      - Label against which predictions will be compared.
  """
  # Notice that the inputs are raw features, not transformed features here.
  raw_feature_spec = _get_raw_feature_spec(schema)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  features = serving_input_receiver.features.copy()
  transformed_features = tf_transform_output.transform_raw_features(features)
  
  # NOTE: Model is driven by transformed features (since training works on the
  # materialized output of TFT), but slicing will happen on raw features.
  features.update(transformed_features)

  return tfma.export.EvalInputReceiver(
      features=features,
      receiver_tensors=serving_input_receiver.receiver_tensors,
      labels=transformed_features[transformed_name(LABEL_KEY)])


def _input_fn(filenames, tf_transform_output, batch_size=200):
  """Generates features and labels for training or evaluation.

  Args:
    filenames: [str] list of gzipped TFRecord files to read data from.
    tf_transform_output: A TFTransformOutput.
    batch_size: int, first-dimension size of the Tensors returned by input_fn.

  Returns:
    A (features, indices) tuple where features is a dictionary of
      Tensors, and indices is a single Tensor of label indices.
  """
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())

  dataset = tf.data.experimental.make_batched_features_dataset(
      filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn)

  transformed_features = dataset.make_one_shot_iterator().get_next()
  # We pop the label because we do not want to use it as a feature while we're
  # training.
  return transformed_features, transformed_features.pop(
      transformed_name(LABEL_KEY))


# TFX will call this function
def trainer_fn(hparams, schema):
  """Build the estimator using the high level API.
  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.
  Returns:
    A dict of the following:
      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(hparams.transform_output)

  train_input_fn = lambda: _input_fn(
      hparams.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(
      hparams.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn,
      max_steps=hparams.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('online-news', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=hparams.eval_steps,
      exporters=[exporter],
      name='online-news-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=hparams.serving_model_dir)

  estimator = _build_estimator(
      # Construct layer sizes with exponential decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=hparams.warm_start_from)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }

Create and run the Trainer component, passing it the file that we created above.

Exercise 6 — Training and Evaluating the Model

  1. Create an instance of the Trainer class. This component needs a lot of arguments:
    • the name of the module file
    • the transformed examples channel
    • the schema channel
    • the transform graph channel
    • training arguments: create an instance of the trainer_pb2.TrainArgs protobuf, specifying num_steps=10000
    • evaluation arguments: create an instance of the trainer_pb2.EvalArgs protobuf, specifying num_steps=5000.
  2. Use the InteractiveContext to run this component.

Hint: again, run Trainer? to find out the argument names you need to set.


In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
# Uses user-provided Python function that implements a model using TensorFlow's
# Estimators API.
trainer = Trainer(
    module_file=_trainer_module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)

Take a peek at the trained model which was exported from Trainer.


In [0]:
train_uri = trainer.outputs['model'].get()[0].uri
serving_model_path = os.path.join(train_uri, 'serving_model_dir', 'export', 'online-news')
latest_serving_model_path = os.path.join(serving_model_path, max(os.listdir(serving_model_path)))
exported_model = tf.saved_model.load(latest_serving_model_path)

In [0]:
exported_model.graph.get_operations()[:10] + ["..."]

Analyze Training with TensorBoard

Use TensorBoard to analyze the model training that was done in Trainer, and see how well our model trained.


In [0]:
%load_ext tensorboard

In [0]:
%tensorboard --logdir {os.path.join(train_uri, 'serving_model_dir')}

The Evaluator Component

The Evaluator component analyzes model performance using the TensorFlow Model Analysis library. It runs inference on slices of the evaluation dataset that the developer defines. Knowing which slices to analyze requires domain knowledge of what is important in this particular use case or domain. The slice chosen for this example is weekday.

Create and run an Evaluator component.

Exercise 7 — Evaluate the Model with TFMA

  1. Create an instance of the Evaluator class. Try to figure out on your own what input channels this component needs by running Evaluator?.
  2. Use the InteractiveContext to run this component.

Hint: the constructor's arguments include both the input channels and the optional output channels; the latter are usually not passed.


In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
model_analyzer = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
)
context.run(model_analyzer)

Let's load the Evaluator results and render them using the tfma.view.render_slicing_metrics() function:


In [0]:
evaluation_uri = model_analyzer.outputs['output'].get()[0].uri
eval_result = tfma.load_eval_result(evaluation_uri)
tfma.view.render_slicing_metrics(eval_result)

We can also pass feature slice specifications if we want to evaluate the quality of the model over specific subsets of the data:


In [0]:
# Uses TFMA to compute evaluation statistics over features of a model.
model_analyzer = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
        evaluator_pb2.SingleSlicingSpec(
            column_for_slicing=['weekday'])
    ]))
context.run(model_analyzer)

Let's look at the results:


In [0]:
evaluation_uri = model_analyzer.outputs['output'].get()[0].uri
eval_result = tfma.load_eval_result(evaluation_uri)
tfma.view.render_slicing_metrics(
      eval_result,
      slicing_spec=tfma.slicer.SingleSliceSpec(columns=['weekday']))

The metrics are also accessible programmatically:


In [0]:
for metric in eval_result.slicing_metrics:
  pp.pprint(metric)

The ModelValidator Component

The ModelValidator component validates your candidate model against the previously deployed model (if any) or against a baseline value, using criteria that you define. If the new model scores better than the previous one, ModelValidator "blesses" it, approving it for deployment.

Exercise 8 — Validate the Model

  1. Create an instance of the ModelValidator class. Try to figure out on your own what input channels this component needs.
  2. Use the InteractiveContext to run this component.

In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
# Performs quality validation of a candidate model (compared to a baseline).
model_validator = ModelValidator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'])
context.run(model_validator)

Examine the output of ModelValidator.


In [0]:
model_validator.outputs

In [0]:
blessing_uri = model_validator.outputs['blessing'].get()[0].uri
!ls -l {blessing_uri}

The Pusher Component

The Pusher component checks whether a model has been "blessed", and if so, deploys it to production by pushing the model to a well known file destination.


In [0]:
# Setup serving path
_serving_model_dir = os.path.join(tempfile.mkdtemp(),
                                  'serving_model/online_news_simple')

Create and run a Pusher component.

Exercise 9 — Push the Model to Production

  1. Create an instance of the Pusher class. Try to figure out on your own what input channels this component needs. You will also need to pass the push destination like this:
    push_destination=pusher_pb2.PushDestination(
     filesystem=pusher_pb2.PushDestination.Filesystem(
             base_directory=_serving_model_dir))
    
  2. Use the InteractiveContext to run this component.

In [0]:


In [0]:


In [0]:


In [0]:


In [0]:

Solution:


In [0]:
# Checks whether the model passed the validation steps and pushes the model
# to a file destination if the check passed.
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=model_validator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher)

Examine the output of Pusher.


In [0]:
pusher.outputs

In [0]:
push_uri = pusher.outputs['pushed_model'].get()[0].uri
latest_version = max(os.listdir(push_uri))
latest_version_path = os.path.join(push_uri, latest_version)
model = tf.saved_model.load(latest_version_path)

Review the model signatures and methods.


In [0]:
for item in model.signatures.items():
  pp.pprint(item)

Alternatively, we can use the command line utility saved_model_cli to look at the MetaGraphDefs (the models) and SignatureDefs (the methods you can call) in our SavedModel. See this discussion of the SavedModel CLI in the TensorFlow Guide.


In [0]:
latest_pushed_model = os.path.join(_serving_model_dir, max(os.listdir(_serving_model_dir)))
!saved_model_cli show --dir {latest_pushed_model} --all

That tells us a few important things about our model. In this case we just trained our model, so we already know the inputs and outputs, but if we didn't this would be important information.

TensorFlow Serving

Now that we have a trained model that has been blessed by ModelValidator, and pushed to our deployment target by Pusher, we can load it into TensorFlow Serving and start serving inference requests.

Add TensorFlow Serving distribution URI as a package source

We're preparing to install TensorFlow Serving using Aptitude since this Colab runs in a Debian environment. We'll add the tensorflow-model-server package to the list of packages that Aptitude knows about. Note that we're running as root.

Note: This example is running TensorFlow Serving natively, but you can also run it in a Docker container, which is one of the easiest ways to get started using TensorFlow Serving.


In [0]:
# This is the same as you would do from your command line, but without the [arch=amd64], and no sudo
# You would instead do:
# echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
# curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

!echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
!apt update

Install TensorFlow Serving

This is all you need: a single command! Note that running TensorFlow Serving in a Docker container is also a great option, with a lot of advantages.


In [0]:
!apt-get install tensorflow-model-server

Start running TensorFlow Serving

This is where we start running TensorFlow Serving and load our model. After it loads we can start making inference requests using REST. There are some important parameters:

  • rest_api_port: The port that you'll use for REST requests.
  • model_name: You'll use this in the URL of REST requests. It can be anything.
  • model_base_path: This is the path to the directory where you've saved your model. Note that this base_path should not include the model version directory, which is why we split it off below.

In [0]:
os.environ["MODEL_DIR"] = os.path.split(latest_pushed_model)[0]

In [0]:
%%bash --bg 
nohup tensorflow_model_server \
  --rest_api_port=8501 \
  --model_name=online_news_simple \
  --model_base_path="${MODEL_DIR}" >server.log 2>&1

In [0]:
!tail server.log

Perform Inference on example data

Let's load a few examples from the eval dataset, strip their labels (the serving model does not expect a label feature), and send them to TensorFlow Serving in a single REST API call.


In [0]:
eval_uri = example_gen.outputs['examples'].get()[1].uri
eval_tfrecord_paths = [os.path.join(eval_uri, name)
                       for name in os.listdir(eval_uri)]

In [0]:
def strip_label(serialized_example):
  example = tf.train.Example.FromString(serialized_example.numpy())
  del example.features.feature["n_shares_percentile"]
  return example.SerializeToString()

dataset = tf.data.TFRecordDataset(eval_tfrecord_paths,
                                  compression_type="GZIP")
serialized_examples = [strip_label(serialized_example)
                       for serialized_example in dataset.take(3)]

In [0]:
def do_inference(server_addr, model_name, serialized_examples):
  """Sends requests to the model and prints the results.
  Args:
    server_addr: network address of model server in "host:port" format
    model_name: name of the model as understood by the model server
    serialized_examples: serialized examples of data to do inference on
  """
  host, port = server_addr.split(':')
  json_examples = []
  
  for serialized_example in serialized_examples:
    # The encoding follows the guidelines in:
    # https://www.tensorflow.org/tfx/serving/api_rest
    example_bytes = base64.b64encode(serialized_example).decode('utf-8')
    predict_request = '{ "b64": "%s" }' % example_bytes
    json_examples.append(predict_request)

  json_request = '{ "instances": [' + ','.join(json_examples) + ']}'

  server_url = 'http://' + host + ':' + port + '/v1/models/' + model_name + ':predict'
  response = requests.post(
      server_url, data=json_request, timeout=5.0)
  response.raise_for_status()
  prediction = response.json()
  print(json.dumps(prediction, indent=4))

In [0]:
do_inference(server_addr='127.0.0.1:8501',
             model_name='online_news_simple',
             serialized_examples=serialized_examples)

Pipeline Complete!

In this example you created a TFX pipeline in a Colab notebook, using the InteractiveContext. Along the way you learned about each of the standard TFX components, but if the standard components don't meet all of your needs you can create your own custom components! Custom components will be covered in a later lesson.