Copyright © 2020 The TensorFlow Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Pipeline Example Performing BERT Preprocessing with TensorFlow Transform

Motivation

Instead of converting the input text for a transformer model into token IDs on the client side, the model exported from this pipeline performs the conversion on the server side.

The pipeline takes advantage of the broad TensorFlow ecosystem, including:

  • Loading the IMDB dataset via TensorFlow Datasets
  • Loading a pre-trained model via tf.hub
  • Manipulating the raw input data with tf.text
  • Building a simple model architecture with Keras
  • Composing the model pipeline with TensorFlow Extended, e.g. TensorFlow Transform, TensorFlow Data Validation and then consuming the tf.Keras model with the latest Trainer component from TFX

The structure of the overall pipeline follows the

Outline

  • Install Required Packages
  • Load the training data set
  • Create the TFX Pipeline
  • Export the trained Model
  • Test the exported Model

Non-Colab users

This notebook was written to run in Google Colab, but it can run in any Jupyter environment. In that case, update the file and directory paths and install TensorFlow>=2.2.0 manually.

Project Setup

Install Required Packages


In [1]:
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass

!pip install -Uq tfx==0.21.4
!pip install -Uq tensorflow-text  # the tf-text version needs to match the tf version
!pip install -Uq tensorflow-model-analysis==0.22.1
!pip install -Uq tensorflow-data-validation==0.22.0
!pip install -Uq tensorflow-transform==0.22.0

print("Restart your runtime after installing the packages")


ERROR: tensorflow-transform 0.21.2 has requirement tensorflow<2.2,>=1.15, but you'll have tensorflow 2.2.0 which is incompatible.
ERROR: multiprocess 0.70.9 has requirement dill>=0.3.1, but you'll have dill 0.3.0 which is incompatible.
ERROR: apache-beam 2.17.0 has requirement httplib2<=0.12.0,>=0.8, but you'll have httplib2 0.17.3 which is incompatible.
ERROR: tfx 0.21.4 has requirement apache-beam[gcp]<2.18,>=2.17, but you'll have apache-beam 2.20.0 which is incompatible.
ERROR: tfx 0.21.4 has requirement pyarrow<0.16,>=0.15, but you'll have pyarrow 0.16.0 which is incompatible.
ERROR: tfx 0.21.4 has requirement tensorflow-model-analysis<0.22,>=0.21.4, but you'll have tensorflow-model-analysis 0.22.1 which is incompatible.
ERROR: tfx 0.21.4 has requirement tfx-bsl<0.22,>=0.21.3, but you'll have tfx-bsl 0.22.0 which is incompatible.
ERROR: tensorflow-transform 0.21.2 has requirement tensorflow<2.2,>=1.15, but you'll have tensorflow 2.2.0 which is incompatible.
ERROR: tensorflow-transform 0.21.2 has requirement tensorflow-metadata<0.22,>=0.21, but you'll have tensorflow-metadata 0.22.0 which is incompatible.
ERROR: tensorflow-transform 0.21.2 has requirement tfx-bsl<0.22,>=0.21.3, but you'll have tfx-bsl 0.22.0 which is incompatible.
ERROR: tensorflow-data-validation 0.21.5 has requirement tensorflow-metadata<0.22,>=0.21.1, but you'll have tensorflow-metadata 0.22.0 which is incompatible.
ERROR: tensorflow-data-validation 0.21.5 has requirement tfx-bsl<0.22,>=0.21.3, but you'll have tfx-bsl 0.22.0 which is incompatible.
ERROR: google-api-python-client 1.7.12 has requirement httplib2<1dev,>=0.17.0, but you'll have httplib2 0.12.0 which is incompatible.
ERROR: tfx 0.21.4 has requirement apache-beam[gcp]<2.18,>=2.17, but you'll have apache-beam 2.20.0 which is incompatible.
ERROR: tfx 0.21.4 has requirement pyarrow<0.16,>=0.15, but you'll have pyarrow 0.16.0 which is incompatible.
ERROR: tfx 0.21.4 has requirement tensorflow-data-validation<0.22,>=0.21.4, but you'll have tensorflow-data-validation 0.22.0 which is incompatible.
ERROR: tfx 0.21.4 has requirement tensorflow-model-analysis<0.22,>=0.21.4, but you'll have tensorflow-model-analysis 0.22.1 which is incompatible.
ERROR: tfx 0.21.4 has requirement tensorflow-transform<0.22,>=0.21.2, but you'll have tensorflow-transform 0.22.0 which is incompatible.
ERROR: tfx 0.21.4 has requirement tfx-bsl<0.22,>=0.21.3, but you'll have tfx-bsl 0.22.0 which is incompatible.
Restart your runtime after installing the packages

Restart the Runtime

Note: After installing the required Python packages, you'll need to restart the Colab Runtime Engine (Menu > Runtime > Restart runtime...)
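
The pip output above reports several version conflicts, so after the restart it can help to confirm which versions actually ended up installed. The following is a minimal sketch, assuming Python 3.8 or newer (where `importlib.metadata` is part of the standard library); the helper name `installed_versions` is hypothetical and not part of the original notebook.

```python
from importlib import metadata  # standard library from Python 3.8 onward

# Packages named in the install cell above.
PACKAGES = [
    "tfx",
    "tensorflow-text",
    "tensorflow-model-analysis",
    "tensorflow-data-validation",
    "tensorflow-transform",
]


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version or 'not installed'}")
```

Comparing the printed versions against the pins in the install cell shows quickly whether the runtime picked up the intended packages.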

Import relevant packages


In [0]:
import glob
import os
import pprint
import re
import tempfile
from shutil import rmtree

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_data_validation as tfdv
import tensorflow_hub as hub
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.saved import saved_transform_io
from tensorflow_transform.tf_metadata import (dataset_metadata, dataset_schema,
                                              metadata_io, schema_utils)
from tfx.components import (Evaluator, ExampleValidator, ImportExampleGen,
                            ModelValidator, Pusher, ResolverNode, SchemaGen,
                            StatisticsGen, Trainer, Transform)
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.proto import evaluator_pb2, example_gen_pb2, pusher_pb2, trainer_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model, ModelBlessing
from tfx.utils.dsl_utils import external_input

import tensorflow_datasets as tfds
import tensorflow_text as text

from tfx.orchestration.experimental.interactive.interactive_context import \
    InteractiveContext

%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip

Check GPU Availability

Check if your Colab notebook is configured to use Graphics Processing Units (GPUs). If zero GPUs are available, switch the notebook to a GPU runtime (Menu > Runtime > Change Runtime Type).


In [3]:
num_gpus_available = len(tf.config.experimental.list_physical_devices('GPU'))
print("Num GPUs Available: ", num_gpus_available)
assert num_gpus_available > 0


Num GPUs Available:  1

Download the IMDB Dataset from TensorFlow Datasets

For our demo example, we use the IMDB data set to train a sentiment model based on the pre-trained BERT model. The data set is provided through TensorFlow Datasets. Our ML pipeline can read TFRecords, but it expects only TFRecord files in the data folder, which is why we delete the additional files provided by TFDS.


In [4]:
!mkdir -p /content/tfds/

def clean_before_download(base_data_dir):
    rmtree(base_data_dir)
    
def delete_unnecessary_files(base_path):
    os.remove(base_path + "dataset_info.json")
    os.remove(base_path + "label.labels.txt")
    
    counter = 2
    for f in glob.glob(base_path + "imdb_reviews-unsupervised.*"):
        os.remove(f)
        counter += 1
    print(f"Deleted {counter} files")

def get_dataset(name='imdb_reviews', version="0.1.0"):

    base_data_dir = "/content/tfds/"
    config = "plain_text"

    clean_before_download(base_data_dir)
    tfds.disable_progress_bar()
    builder = tfds.text.IMDBReviews(data_dir=base_data_dir, 
                                    config=config, 
                                    version=version)
    download_config = tfds.download.DownloadConfig(
        download_mode=tfds.GenerateMode.FORCE_REDOWNLOAD)
    builder.download_and_prepare(download_config=download_config)

    base_tfrecords_filename = os.path.join(base_data_dir, "imdb_reviews", config, version, "")
    train_tfrecords_filename = base_tfrecords_filename + "imdb_reviews-train*"
    test_tfrecords_filename = base_tfrecords_filename + "imdb_reviews-test*"
    label_filename = os.path.join(base_tfrecords_filename, "label.labels.txt")
    # read the label names before the label file is deleted below
    with open(label_filename) as label_file:
        labels = [label.rstrip('\n') for label in label_file]
    delete_unnecessary_files(base_tfrecords_filename)
    return (train_tfrecords_filename, test_tfrecords_filename), labels

tfrecords_filenames, labels = get_dataset()


Downloading and preparing dataset imdb_reviews/plain_text/0.1.0 (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /content/tfds/imdb_reviews/plain_text/0.1.0...
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_datasets/core/file_format_adapter.py:210: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
Dataset imdb_reviews downloaded and prepared to /content/tfds/imdb_reviews/plain_text/0.1.0. Subsequent calls will reuse this data.
Deleted 22 files
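
Before deleting anything, it can be useful to list which files in the data folder are not TFRecord shards of the splits we want to keep. The sketch below does this with the standard library only; the helper `find_extra_files` and the shard file names in the demo are hypothetical and only mimic the layout described above.

```python
import fnmatch
import os
import tempfile


def find_extra_files(data_dir, keep_patterns):
    """Return sorted names of files in data_dir that match none of keep_patterns."""
    extra = []
    for name in sorted(os.listdir(data_dir)):
        if not any(fnmatch.fnmatch(name, pattern) for pattern in keep_patterns):
            extra.append(name)
    return extra


# Demo on a throwaway directory with made-up file names.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ["imdb_reviews-train.tfrecord-00000-of-00010",
                 "imdb_reviews-test.tfrecord-00000-of-00010",
                 "imdb_reviews-unsupervised.tfrecord-00000-of-00020",
                 "dataset_info.json",
                 "label.labels.txt"]:
        open(os.path.join(demo_dir, name), "w").close()

    extra_files = find_extra_files(
        demo_dir, ["imdb_reviews-train*", "imdb_reviews-test*"])
    print(extra_files)
```

Listing first and deleting second makes it easier to notice if a new TFDS version adds files that the hard-coded `os.remove` calls do not cover.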

Helper function to load the BERT model as Keras layer

In our pipeline components, we reuse the BERT layer from tf.hub in two places:

  • in the model architecture when we define our Keras model
  • in our preprocessing function when we extract the BERT settings (casing and vocab file path) to reuse the settings during the tokenization

In [5]:
%%skip_for_export
%%writefile bert.py

import tensorflow_hub as hub

BERT_TFHUB_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2"

def load_bert_layer(model_url=BERT_TFHUB_URL):
    # Load the pre-trained BERT model as layer in Keras
    bert_layer = hub.KerasLayer(
        handle=model_url,
        trainable=True)
    return bert_layer


Writing bert.py
This cell will be skipped during export to pipeline.

TFX Pipeline

The TensorFlow Extended pipeline largely follows the example setup shown here. We only note deviations from the original setup.

Initializing the Interactive TFX Pipeline


In [6]:
context = InteractiveContext()


WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/metadata.sqlite.

Loading the dataset


In [7]:
output = example_gen_pb2.Output(
             split_config=example_gen_pb2.SplitConfig(splits=[
                 example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=45),
                 example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=5)
             ]))
# Load the data from our prepared TFDS folder
examples = external_input("/content/tfds/imdb_reviews/plain_text/0.1.0")
example_gen = ImportExampleGen(input=examples, output_config=output)

context.run(example_gen)


WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
Out[7]:
ExecutionResult at 0x7f588fc5c8d0
.execution_id 1
.component
.component.inputs
['input']
.component.outputs
['examples']
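
The `hash_buckets` values in the split config above set the ratio between the splits: 45 of 50 buckets go to 'train' and 5 of 50 to 'eval'. The sketch below illustrates the idea in plain Python. It is not the TFX implementation: the helper names are hypothetical, and MD5 is only a stand-in for whatever fingerprint function ExampleGen uses internally.

```python
import hashlib

SPLITS = [("train", 45), ("eval", 5)]  # same bucket counts as the cell above


def split_fractions(splits):
    """Fraction of all hash buckets that belongs to each split."""
    total = sum(buckets for _, buckets in splits)
    return {name: buckets / total for name, buckets in splits}


def assign_split(record: bytes, splits):
    """Assign a record to a split by hashing it into one of the buckets."""
    total = sum(buckets for _, buckets in splits)
    # MD5 is a stand-in here; it is not the fingerprint function TFX uses.
    bucket = int(hashlib.md5(record).hexdigest(), 16) % total
    upper = 0
    for name, buckets in splits:
        upper += buckets
        if bucket < upper:
            return name
    raise AssertionError("unreachable: bucket is always below the total")


print(split_fractions(SPLITS))

# Hashing is deterministic, so a record always lands in the same split.
counts = {"train": 0, "eval": 0}
for i in range(10000):
    counts[assign_split(f"review-{i}".encode(), SPLITS)] += 1
print(counts)
```

Because the assignment depends only on the record content, re-running the component on the same data reproduces the same split, while the observed ratio is only approximately 90:10.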

In [8]:
%%skip_for_export

for artifact in example_gen.outputs['examples'].get():
    print(artifact.uri)


/tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/ImportExampleGen/examples/1
This cell will be skipped during export to pipeline.

TensorFlow Data Validation


In [9]:
%%skip_for_export

statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'])
context.run(statistics_gen)

context.show(statistics_gen.outputs['statistics'])


Artifact at /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/StatisticsGen/statistics/2

[Statistics visualization: 'train' split and 'eval' split]


This cell will be skipped during export to pipeline.

In [10]:
%%skip_for_export

schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)
context.run(schema_gen)

context.show(schema_gen.outputs['schema'])


Artifact at /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/SchemaGen/schema/3

Feature name  Type   Presence  Valency  Domain
'text'        BYTES  required           -
'label'       INT    required           -
This cell will be skipped during export to pipeline.

In [11]:
%%skip_for_export

# check the data schema for the type of input tensors
tfdv.load_schema_text(schema_gen.outputs['schema'].get()[0].uri + "/schema.pbtxt")


Out[11]:
feature {
  name: "text"
  type: BYTES
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "label"
  type: INT
  bool_domain {
  }
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
This cell will be skipped during export to pipeline.

In [12]:
%%skip_for_export

example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
context.run(example_validator)

context.show(example_validator.outputs['anomalies'])


Artifact at /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/ExampleValidator/anomalies/4

No anomalies found.

This cell will be skipped during export to pipeline.

TensorFlow Transform

This is where we perform the BERT processing.


In [13]:
%%skip_for_export
%%writefile transform.py

import tensorflow as tf
import tensorflow_text as text

from bert import load_bert_layer

MAX_SEQ_LEN = 64  # max number is 512
do_lower_case = load_bert_layer().resolved_object.do_lower_case.numpy()

def preprocessing_fn(inputs):
    """Preprocess the input text column into the transformed columns:
        * input token ids
        * input mask
        * input type ids
    """

    CLS_ID = tf.constant(101, dtype=tf.int64)
    SEP_ID = tf.constant(102, dtype=tf.int64)
    PAD_ID = tf.constant(0, dtype=tf.int64)

    vocab_file_path = load_bert_layer().resolved_object.vocab_file.asset_path
    
    bert_tokenizer = text.BertTokenizer(vocab_lookup_table=vocab_file_path, 
                                        token_out_type=tf.int64, 
                                        lower_case=do_lower_case) 
    
    def tokenize_text(text, sequence_length=MAX_SEQ_LEN):
        """
        Perform the BERT preprocessing from text -> input token ids
        """

        # convert text into token ids
        tokens = bert_tokenizer.tokenize(text)
        
        # flatten the ragged (batch, words, wordpieces) output to (batch, tokens)
        tokens = tokens.merge_dims(1, 2)

        # truncate to leave room for the two special tokens, then add the
        # start ([CLS]) and end ([SEP]) token ids to the id sequence
        start_tokens = tf.fill([tf.shape(text)[0], 1], CLS_ID)
        end_tokens = tf.fill([tf.shape(text)[0], 1], SEP_ID)
        tokens = tokens[:, :sequence_length - 2]
        tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)

        # pad shorter sequences with the pad token id
        tokens = tokens.to_tensor(default_value=PAD_ID)
        pad = sequence_length - tf.shape(tokens)[1]
        tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values=PAD_ID)

        # and finally reshape the word token ids to fit the output 
        # data structure of TFT  
        return tf.reshape(tokens, [-1, sequence_length])

    def preprocess_bert_input(text):
        """
        Convert input text into the input_word_ids, input_mask, input_type_ids
        """
        input_word_ids = tokenize_text(text)
        input_mask = tf.cast(input_word_ids > 0, tf.int64)
        input_mask = tf.reshape(input_mask, [-1, MAX_SEQ_LEN])
        
        # single-segment input: every position gets segment id 0
        input_type_ids = tf.zeros_like(input_mask)

        return (
            input_word_ids, 
            input_mask,
            input_type_ids
        )

    input_word_ids, input_mask, input_type_ids = \
        preprocess_bert_input(tf.squeeze(inputs['text'], axis=1))

    return {
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids,
        'label': inputs['label']
    }


Writing transform.py
This cell will be skipped during export to pipeline.
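
The steps inside `tokenize_text` and `preprocess_bert_input` (truncate, add the start and end token ids, pad, derive the mask) can be easier to follow on plain Python lists. The sketch below mirrors that logic for a single sequence without TensorFlow; the function name `to_bert_inputs` and the token ids in the demo are made up for illustration.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # same ids as in transform.py


def to_bert_inputs(token_ids, sequence_length):
    """Build fixed-length input_word_ids, input_mask and input_type_ids lists."""
    # keep room for the two special tokens
    body = token_ids[:sequence_length - 2]
    word_ids = [CLS_ID] + body + [SEP_ID]
    # pad on the right up to the fixed length
    word_ids = word_ids + [PAD_ID] * (sequence_length - len(word_ids))
    # the mask marks real tokens (1) versus padding (0)
    mask = [1 if token_id > 0 else 0 for token_id in word_ids]
    # single-segment input: all segment ids are 0
    type_ids = [0] * sequence_length
    return word_ids, mask, type_ids


short = to_bert_inputs([2023, 3185, 2001], sequence_length=8)
long_ = to_bert_inputs(list(range(1000, 1020)), sequence_length=8)
print(short)
print(long_)
```

The short input is padded with zeros after the end token, while the long input is cut to six tokens so that the start and end ids still fit into the eight positions.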

In [14]:
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath("transform.py"))
context.run(transform)


WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tfx/components/transform/executor.py:511: Schema (from tensorflow_transform.tf_metadata.dataset_schema) is deprecated and will be removed in a future version.
Instructions for updating:
Schema is a deprecated, use schema_utils.schema_from_feature_spec to create a `Schema`
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_transform/tf_utils.py:220: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use ref() instead.
WARNING:tensorflow:Tensorflow version (2.2.0) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 
WARNING:apache_beam.utils.interactive_utils:Failed to alter the label of a transform with the ipython prompt metadata. Cannot figure out the pipeline that the given pvalueish ({DatasetKey(key='tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd-ImportExampleGen-examples-1-train-STAR'): <PCollection[Decode[AnalysisIndex0]/ApplyDecodeFn.None] at 0x7f5800792828>}, None, {'_schema': feature {
  name: "text"
  type: BYTES
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "label"
  type: INT
  bool_domain {
  }
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
}) belongs to. Thus noop.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/Transform/transform_graph/5/.temp_path/tftransform_tmp/62fe0ce1bdf44fb5bd5020d2b4728748/assets
INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/Transform/transform_graph/5/.temp_path/tftransform_tmp/62fe0ce1bdf44fb5bd5020d2b4728748/saved_model.pb
WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\016\n\014asset_path:0\022\tvocab.txt"

Out[14]:
ExecutionResult at 0x7f5870e55828
.execution_id 5
.component
.component.inputs
['examples']
['schema']
.component.outputs
['transform_graph']
['transformed_examples']

Check the Output Data Structure of the TF Transform Operation


In [15]:
from tfx_bsl.coders.example_coder import ExampleToNumpyDict

pp = pprint.PrettyPrinter()

# Get the URI of the output artifact representing the transformed examples, which is a directory
train_uri = transform.outputs['transformed_examples'].get()[0].uri

print(train_uri)

# Get the list of files in this directory (all compressed TFRecord files)
tfrecord_folders = [os.path.join(train_uri, name) for name in os.listdir(train_uri)]
tfrecord_filenames = []
for tfrecord_folder in tfrecord_folders:
    for name in os.listdir(tfrecord_folder):
        tfrecord_filenames.append(os.path.join(tfrecord_folder, name))


# Create a TFRecordDataset to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

for tfrecord in dataset.take(1):
    serialized_example = tfrecord.numpy()
    example = ExampleToNumpyDict(serialized_example)
    pp.pprint(example)


/tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/Transform/transformed_examples/5
{'input_mask': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 'input_type_ids': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'input_word_ids': array([  101,  1036,  5192,  3894,  1005, 27639,  2015,  1996,  6569,
        1998, 21606,  1997,  1996,  2034,  3185,  9501,  1012,  2009,
        2036,  3065,  1996,  2373,  1997,  2143,  1999,  2049,  3754,
        2000,  3288,  1996,  2088,  1037,  2210,  3553,  1010,  9462,
        3451, 13500,  1998,  2000,  7969,  9731,  2005,  8213,  2664,
        2000,  2272,  1012,  5121,  1010,  3087,  2040,  5621,  7459,
        1996,  2396,  1997,  1996,  4367,  3861,  2097,  5959,  2023,
         102]),
 'label': array([1])}
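
The printed example shows the expected pattern: the sequence starts with id 101, ends with id 102, and the mask covers every non-padding position. A small check like the one below could be run over a few decoded examples to confirm this. It is only a sketch: `check_bert_example` is a hypothetical helper, and the demo dictionaries are shortened, made-up examples rather than pipeline output.

```python
def check_bert_example(example, sequence_length=64, cls_id=101, sep_id=102):
    """Return a list of problems found in one transformed example (empty if none)."""
    problems = []
    word_ids = list(example["input_word_ids"])
    mask = list(example["input_mask"])
    type_ids = list(example["input_type_ids"])

    for name, values in [("input_word_ids", word_ids),
                         ("input_mask", mask),
                         ("input_type_ids", type_ids)]:
        if len(values) != sequence_length:
            problems.append(f"{name} has length {len(values)}")
    if word_ids and word_ids[0] != cls_id:
        problems.append("sequence does not start with [CLS]")
    num_tokens = sum(mask)
    if mask != [1] * num_tokens + [0] * (len(mask) - num_tokens):
        problems.append("input_mask is not a run of 1s followed by 0s")
    if 0 < num_tokens <= len(word_ids) and word_ids[num_tokens - 1] != sep_id:
        problems.append("last unmasked token is not [SEP]")
    return problems


good = {"input_word_ids": [101, 7, 8, 102, 0, 0],
        "input_mask": [1, 1, 1, 1, 0, 0],
        "input_type_ids": [0] * 6}
bad = {"input_word_ids": [7, 8, 102, 0, 0, 0],
       "input_mask": [1, 1, 1, 0, 0, 0],
       "input_type_ids": [0] * 6}

good_problems = check_bert_example(good, sequence_length=6)
bad_problems = check_bert_example(bad, sequence_length=6)
print(good_problems)
print(bad_problems)
```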

Training of the Keras Model


In [18]:
%%skip_for_export
%%writefile trainer.py

from typing import Text

import absl
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow import keras
from tensorflow_transform.tf_metadata import schema_utils
from tfx.components.trainer.executor import TrainerFnArgs


_LABEL_KEY = 'label'
BERT_TFHUB_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2"


def _gzip_reader_fn(filenames):
    """Small utility returning a record reader that can read gzip'ed files."""
    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

def load_bert_layer(model_url=BERT_TFHUB_URL):
    # Load the pre-trained BERT model as layer in Keras
    bert_layer = hub.KerasLayer(
        handle=model_url,
        trainable=False)  # set to True to fine-tune the BERT weights
    return bert_layer

def get_model(tf_transform_output, max_seq_length=64, num_labels=2):

    # dynamically create inputs for all outputs of our transform graph
    feature_spec = tf_transform_output.transformed_feature_spec()  
    feature_spec.pop(_LABEL_KEY)

    inputs = {
        key: tf.keras.layers.Input(shape=(max_seq_length,), name=key, dtype=tf.int64)
            for key in feature_spec.keys()
    }

    input_word_ids = tf.cast(inputs["input_word_ids"], dtype=tf.int32)
    input_mask = tf.cast(inputs["input_mask"], dtype=tf.int32)
    input_type_ids = tf.cast(inputs["input_type_ids"], dtype=tf.int32)

    bert_layer = load_bert_layer()
    pooled_output, _ = bert_layer(
        [input_word_ids, 
         input_mask, 
         input_type_ids
        ]
    )
    
    # Add additional layers depending on your problem
    x = tf.keras.layers.Dense(256, activation='relu')(pooled_output)
    dense = tf.keras.layers.Dense(64, activation='relu')(x)
    pred = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

    keras_model = tf.keras.Model(
        inputs=[
                inputs['input_word_ids'], 
                inputs['input_mask'], 
                inputs['input_type_ids']], 
        outputs=pred)
    keras_model.compile(loss='binary_crossentropy', 
                        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), 
                        metrics=['accuracy']
                        )
    return keras_model


def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Returns a function that parses a serialized tf.Example and applies TFT."""

    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        """Returns the output to be used in the serving signature."""
        feature_spec = tf_transform_output.raw_feature_spec()
        feature_spec.pop(_LABEL_KEY)
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)

        transformed_features = model.tft_layer(parsed_features)

        outputs = model(transformed_features)
        return {'outputs': outputs}

    return serve_tf_examples_fn

def _input_fn(file_pattern: Text,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 32) -> tf.data.Dataset:
    """Generates features and label for tuning/training.

    Args:
      file_pattern: input tfrecord file pattern.
      tf_transform_output: A TFTransformOutput.
      batch_size: representing the number of consecutive elements of returned
        dataset to combine in a single batch

    Returns:
      A dataset that contains (features, indices) tuple where features is a
        dictionary of Tensors, and indices is a single Tensor of label indices.
    """
    transformed_feature_spec = (
        tf_transform_output.transformed_feature_spec().copy())

    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=file_pattern,
        batch_size=batch_size,
        features=transformed_feature_spec,
        reader=_gzip_reader_fn,
        label_key=_LABEL_KEY)

    return dataset

# TFX Trainer will call this function.
def run_fn(fn_args: TrainerFnArgs):
    """Train the model based on given args.

    Args:
      fn_args: Holds args used to train the model as name/value pairs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(fn_args.train_files, tf_transform_output, 32)
    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, 32)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = get_model(tf_transform_output=tf_transform_output)

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps)

    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(model,
                                      tf_transform_output).get_concrete_function(
                                          tf.TensorSpec(
                                              shape=[None],
                                              dtype=tf.string,
                                              name='examples')),
    }
    model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)


Overwriting trainer.py
This cell will be skipped during export to pipeline.
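
When adjusting the number of training and evaluation steps passed to the Trainer component, it helps to translate steps into passes over the data. The sketch below does the arithmetic under stated assumptions: roughly 50,000 labelled IMDB reviews (train plus test files) in the data folder, the 45:5 split from the ExampleGen step, and the batch size of 32 used in `run_fn`. The helper name is hypothetical, and the split sizes are approximate because the split is hash-based.

```python
def passes_over_data(num_steps, batch_size, num_examples):
    """Number of times `num_steps` batches cover a data set of `num_examples`."""
    return num_steps * batch_size / num_examples


# Assumed sizes: 50,000 labelled reviews split 45:5 by ExampleGen.
TRAIN_EXAMPLES = 50_000 * 45 // 50   # about 45,000
EVAL_EXAMPLES = 50_000 * 5 // 50     # about 5,000
BATCH_SIZE = 32                      # as passed to _input_fn in run_fn

train_passes = passes_over_data(10_000, BATCH_SIZE, TRAIN_EXAMPLES)
eval_passes = passes_over_data(1_000, BATCH_SIZE, EVAL_EXAMPLES)
print(f"training passes:   {train_passes:.2f}")
print(f"evaluation passes: {eval_passes:.2f}")
```

Under these assumptions, both step counts cover their split several times, so fewer steps may be enough for a quick test run.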

In [19]:
# NOTE: Adjust the number of training and evaluation steps
TRAINING_STEPS = 10000
EVALUATION_STEPS = 1000

trainer = Trainer(
    module_file=os.path.abspath("trainer.py"),
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=TRAINING_STEPS),
    eval_args=trainer_pb2.EvalArgs(num_steps=EVALUATION_STEPS))
context.run(trainer)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
10000/10000 [==============================] - 884s 88ms/step - accuracy: 0.7038 - loss: 0.5696 - val_accuracy: 0.6818 - val_loss: 0.6058
WARNING:tensorflow:Tensorflow version (2.2.0) found. TransformFeaturesLayer may not work as intended if the SavedModel contains an initialization op.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-05-19T19_40_33.919799-61fc71kd/Trainer/model/7/serving_model_dir/assets
Out[19]:
ExecutionResult at 0x7f57bbb15f60
.execution_id: 7
.component.inputs: ['examples'], ['transform_graph'], ['schema']
.component.outputs: ['model']
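
The step counts above are easier to choose once they are translated into examples and passes over the data. `make_batched_features_dataset` repeats its input indefinitely by default, which is why `model.fit` needs explicit step counts. The following is a minimal sketch of that arithmetic in plain Python. The batch size of 32 is the value hard-coded in `run_fn`; the training set size of 25,000 is an assumption (the size of the standard IMDB training split) and should be replaced with the size of the split actually used.

```python
def examples_seen(steps: int, batch_size: int) -> int:
    """Number of examples consumed after `steps` batches."""
    return steps * batch_size


def approx_epochs(steps: int, batch_size: int, dataset_size: int) -> float:
    """Approximate number of passes over a dataset of `dataset_size` examples."""
    return examples_seen(steps, batch_size) / dataset_size


TRAINING_STEPS = 10000
BATCH_SIZE = 32  # hard-coded in the _input_fn calls inside run_fn
ASSUMED_TRAIN_EXAMPLES = 25000  # assumption: standard IMDB training split

print(examples_seen(TRAINING_STEPS, BATCH_SIZE))
print(approx_epochs(TRAINING_STEPS, BATCH_SIZE, ASSUMED_TRAIN_EXAMPLES))
```

Under these assumptions, 10,000 steps correspond to 320,000 examples, or roughly 12.8 passes over the training data.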

In [20]:
model_resolver = ResolverNode(
    instance_name='latest_blessed_model_resolver',
    resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
    model=Channel(type=Model),
    model_blessing=Channel(type=ModelBlessing))

context.run(model_resolver)


Out[20]:
ExecutionResult at 0x7f5886d68320
.execution_id: 8
.component: <tfx.components.common_nodes.resolver_node.ResolverNode object at 0x7f57bbb15400>
.component.inputs: ['model'], ['model_blessing']
.component.outputs: ['model'], ['model_blessing']
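
The resolver's job is to find the most recent model that a previous Evaluator run has blessed, so that it can serve as the baseline. The selection rule can be sketched in plain Python as follows. This is an illustration of the idea only, not the TFX implementation: `ModelRecord`, `model_id`, and `latest_blessed` are hypothetical names, and it assumes that a larger id means a newer model.

```python
from typing import NamedTuple, Optional, Sequence


class ModelRecord(NamedTuple):
    """Hypothetical stand-in for a model artifact and its blessing state."""
    model_id: int   # assumption: larger id means a more recent model
    blessed: bool


def latest_blessed(models: Sequence[ModelRecord]) -> Optional[ModelRecord]:
    """Return the newest blessed model, or None if nothing has been blessed."""
    blessed = [m for m in models if m.blessed]
    if not blessed:
        return None
    return max(blessed, key=lambda m: m.model_id)


history = [ModelRecord(1, True), ModelRecord(2, False), ModelRecord(3, True),
           ModelRecord(4, False)]
print(latest_blessed(history))   # newest blessed model in the history
print(latest_blessed([]))        # first run: no baseline available
```

On the very first pipeline run no blessed model exists yet, so there is no baseline to resolve.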

TensorFlow Model Evaluation


In [21]:
eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(label_key='label')
    ],
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[
                tfma.MetricConfig(class_name='ExampleCount')
            ],
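            # Bless the model only if binary_accuracy is at least 0.5 and does
            # not drop below the baseline model's value (HIGHER_IS_BETTER).
            thresholds={
                'binary_accuracy': tfma.MetricThreshold(
                    value_threshold=tfma.GenericValueThreshold(
                        lower_bound={'value': 0.5}),
                    change_threshold=tfma.GenericChangeThreshold(
                        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                        absolute={'value': -1e-10}))
            }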
        )
    ],
    slicing_specs=[
        # An empty slice spec means the overall slice, i.e. the whole dataset.
        tfma.SlicingSpec(),
    ])

evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config
)

context.run(evaluator)


WARNING:absl:inputs do not match those expected by the model: input_names=['examples'], found in extracts={}
Out[21]:
ExecutionResult at 0x7f57bbc05f60
.execution_id: 9
.component.inputs: ['examples'], ['model'], ['baseline_model']
.component.outputs: ['evaluation'], ['blessing']

In [22]:
# Check the blessing
!ls {evaluator.outputs['blessing'].get()[0].uri}


BLESSED
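
The `BLESSED` marker is the result of the two thresholds configured in `eval_config`: an absolute lower bound of 0.5 on `binary_accuracy`, and a change threshold that requires the candidate not to fall more than 1e-10 below the baseline. The sketch below restates that decision rule in plain Python. It is a simplification, not the TFMA implementation: `passes_thresholds` is a hypothetical helper, the exact handling of boundary values may differ, and skipping the change check when no baseline exists is an assumption about first-run behavior. The accuracy values are illustrative.

```python
from typing import Optional


def passes_thresholds(candidate: float,
                      baseline: Optional[float],
                      lower_bound: float = 0.5,
                      min_absolute_change: float = -1e-10) -> bool:
    """Simplified blessing rule for a HIGHER_IS_BETTER metric."""
    # Value threshold: the metric itself must reach the lower bound.
    if candidate < lower_bound:
        return False
    # Assumption: without a baseline (first run) the change check is skipped.
    if baseline is None:
        return True
    # Change threshold: candidate minus baseline must not be below the limit.
    return (candidate - baseline) >= min_absolute_change


print(passes_thresholds(0.68, None))   # first run, above the lower bound
print(passes_thresholds(0.45, None))   # below the lower bound
print(passes_thresholds(0.70, 0.75))   # worse than the baseline
print(passes_thresholds(0.75, 0.70))   # better than the baseline
```

The small negative tolerance means that a candidate matching the baseline still passes, while any real regression does not.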

Model Export for Serving


In [23]:
serving_model_dir = "/content/serving_model_dir"

# -p avoids an error if the directory already exists, e.g. when re-running the cell.
!mkdir -p {serving_model_dir}

pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir)))

context.run(pusher)


Out[23]:
ExecutionResult at 0x7f57bbb15828
.execution_id: 10
.component.inputs: ['model'], ['model_blessing']
.component.outputs: ['pushed_model']

Test your Exported Model


In [24]:
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

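push_uri = pusher.outputs['pushed_model'].get()[0].uri
# Version directories are named with integer timestamps, so compare them
# numerically rather than as strings.
latest_version = max(
    (d for d in os.listdir(push_uri) if d.isdigit()), key=int)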
latest_version_path = os.path.join(push_uri, latest_version)
loaded_model = tf.saved_model.load(latest_version_path)

example_str = b"This is the finest show ever produced for TV. Each episode is a triumph. The casting, the writing, the timing are all second to none. This cast performs miracles."
example = tf.train.Example(features=tf.train.Features(feature={
    'text': _bytes_feature(example_str)}))

serialized_example = example.SerializeToString()
f = loaded_model.signatures["serving_default"]
print(f(tf.constant([serialized_example])))


{'outputs': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.7497024]], dtype=float32)>}
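
Because the serving signature accepts serialized `tf.train.Example` strings under the input name `examples`, the same request can also be sent to a model server over REST. The sketch below builds a JSON payload in the row format used by the TensorFlow Serving predict API, where binary values are wrapped as `{"b64": ...}`. It uses only the standard library; `build_predict_request` is a hypothetical helper, and the byte string is a stand-in for the `serialized_example` created above. The endpoint in the comment assumes a local TensorFlow Serving instance on its default REST port with a hypothetical model name.

```python
import base64
import json
from typing import List


def build_predict_request(serialized_examples: List[bytes],
                          signature_name: str = "serving_default") -> str:
    """Build a JSON body with base64-encoded serialized tf.Example records."""
    instances = [
        # "examples" matches the input name of the exported serving signature.
        {"examples": {"b64": base64.b64encode(s).decode("utf-8")}}
        for s in serialized_examples
    ]
    return json.dumps({"signature_name": signature_name,
                       "instances": instances})


# Stand-in bytes; in the notebook this would be `serialized_example`.
payload = build_predict_request([b"placeholder-serialized-example"])
print(payload)

# Assumed endpoint (hypothetical model name "bert_imdb"):
#   POST http://localhost:8501/v1/models/bert_imdb:predict
```

The client only ships raw text wrapped in a `tf.train.Example`; tokenization and the conversion to token ids happen inside the exported model.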

Upload the Exported Model to GDrive


In [32]:
from google.colab import drive
drive.mount('/content/drive')

# -p avoids an error if the directory already exists from an earlier run.
!mkdir -p /content/drive/My\ Drive/exported_model
!cp -r {pusher.outputs['pushed_model'].get()[0].uri} /content/drive/My\ Drive/exported_model/

drive.flush_and_unmount()
print('Exported model has been uploaded to your Google Drive.')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Exported model has been uploaded to your Google Drive.