In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This notebook demonstrates how to build and use a custom component in your TFX pipeline. We will train an image classification model on the UC Merced Land Use Dataset of aerial pictures, using a custom component to perform image augmentation.
In [0]:
!pip install -q -U \
tensorflow-gpu==2.0.0 \
tfx==0.15.0rc0 \
tensorflow_datasets \
tensorflow-addons
In [0]:
import os
import tempfile
import urllib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
keras = tf.keras
K = keras.backend
import apache_beam as beam
import tensorflow_data_validation as tfdv
import tensorflow_datasets as tfds
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.eval_saved_model.export import build_parsing_eval_input_receiver_fn
from tensorflow_addons.image import rotate
import tfx
from tfx.components.base import base_component
from tfx.components.base import base_executor
from tfx.components.base import executor_spec
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.import_example_gen.component import ImportExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.types import artifact_utils
from tfx.types import standard_artifacts
from tfx.types.component_spec import ChannelParameter
from tfx.types.component_spec import ExecutionParameter
from tfx.utils.dsl_utils import external_input
from tfx.proto import evaluator_pb2
from tfx.proto import example_gen_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
Check the installed TensorFlow and TFX versions:
In [0]:
# Report the installed TensorFlow and TFX versions.
for library_name, library in (('TensorFlow', tf), ('TFX', tfx)):
    print('{} version: {}'.format(library_name, library.__version__))
In [0]:
# Load the UC Merced training split as (image, label) pairs, plus its metadata.
train_set, ds_info = tfds.load(
    "uc_merced",
    split="train",
    with_info=True,
    as_supervised=True,
)
In [0]:
# Display the DatasetInfo returned by tfds.load(with_info=True).
ds_info
In [0]:
# Number of label classes. The trainer module hard-codes NUM_CLASSES = 21;
# keep the two in sync.
n_classes = ds_info.features['label'].num_classes
n_classes
In [0]:
# Human-readable class names, indexed by integer label (used for plot titles).
class_names = ds_info.features['label'].names
class_names
In [0]:
# Show the first num_rows * num_cols training images, titled by class name.
num_rows, num_cols = 10, 5
num_images = num_rows * num_cols
plt.figure(figsize=(4 * num_cols, 4 * num_rows))
for position, (image, label) in enumerate(train_set.take(num_images), start=1):
    plt.subplot(num_rows, num_cols, position)
    plt.imshow(image)
    plt.title(class_names[label])
    plt.axis('off')
plt.show()
Note that a few images are slightly smaller than 256x256:
In [0]:
# Print the shape of every training image that is not a full 256x256 RGB image.
for img, label in train_set:
    if img.shape == (256, 256, 3):
        continue
    print(img.shape)
In [0]:
# InteractiveContext runs TFX components one at a time from this notebook
# (see the context.run(...) calls below), so their output artifacts can be
# inspected between steps.
context = InteractiveContext()
In [0]:
# Locate the TFRecord files that tfds.load() downloaded. Prefer the directory
# reported by ds_info.data_dir (when available), which follows the installed
# dataset version. Otherwise fall back to the default TFDS location for
# version 0.0.1.
HOME = os.path.expanduser('~')
examples_path = getattr(ds_info, 'data_dir', None) or os.path.join(
    HOME, "tensorflow_datasets", "uc_merced", "0.0.1")
# Glob the shards instead of hard-coding a single shard filename.
train_tfrecord_files = sorted(tf.io.gfile.glob(
    os.path.join(examples_path, "uc_merced-train.tfrecord*")))
assert train_tfrecord_files, (
    "No uc_merced training TFRecords found in " + examples_path)
dataset = tf.data.TFRecordDataset(train_tfrecord_files)
# Decode the first serialized tf.Example into a feature dict and its PNG image.
decoder = tfdv.TFExampleDecoder()
for tfrecord in dataset.take(1):
    example = decoder.decode(tfrecord.numpy())
    img = tf.io.decode_png(example['image'][0])
In [0]:
# The decoded first record: a dict from feature name (e.g. 'image', 'label')
# to a numpy array of values.
example
In [0]:
# Display the decoded PNG without axis ticks or frame.
plt.imshow(img)
plt.gca().set_axis_off()
plt.show()
In [0]:
# Shape of the decoded image: (height, width, channels).
img.shape
In [0]:
# Ingest the TFDS TFRecord files with ImportExampleGen. Only a 'train' input
# split pattern is given here; the custom executor later reads both the
# 'train' and 'eval' splits from this component's output.
# (tfx.components.example_gen.utils.make_default_input_config(
#  split_pattern='uc_merced-train*') builds an equivalent input config.)
input_data = external_input(examples_path)
train_split = example_gen_pb2.Input.Split(name='train',
                                          pattern='uc_merced-train*')
input_config = example_gen_pb2.Input(splits=[train_split])
example_gen = ImportExampleGen(input=input_data, input_config=input_config)
context.run(example_gen)
In [0]:
# Compute dataset statistics, used for visualization, schema inference and
# example validation.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
context.run(statistics_gen)
In [0]:
# Infer a schema from the computed statistics.
infer_schema = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
)
context.run(infer_schema)
In [0]:
# Load the schema inferred by SchemaGen so it can be displayed. The variable is
# named for what it holds (the schema artifact directory), not the train split.
schema_uri = infer_schema.outputs['schema'].get()[0].uri
schema_filename = os.path.join(schema_uri, "schema.pbtxt")
schema = tfx.utils.io_utils.parse_pbtxt_file(file_name=schema_filename,
                                             message=schema_pb2.Schema())
In [0]:
# Render the inferred schema in the notebook.
tfdv.display_schema(schema)
In [0]:
# Detect anomalies by checking the statistics against the inferred schema.
validate_stats = ExampleValidator(
    schema=infer_schema.outputs['schema'],
    statistics=statistics_gen.outputs['statistics'],
)
context.run(validate_stats)
In [0]:
"""Example of a custom TFX component for data augmentation.
This component along with other custom component related code will only serve as
an example and will not be supported by TFX team.
"""
class DataAugmentationComponentSpec(tfx.types.ComponentSpec):
    """Declares the inputs, outputs and parameters of the augmentation component."""

    # Examples to augment; the executor reads their 'train' and 'eval' splits.
    INPUTS = {
        'examples': ChannelParameter(type=standard_artifacts.Examples),
    }
    # Augmented 'train' split plus the pass-through 'eval' split.
    OUTPUTS = {
        'augmented_data': ChannelParameter(type=standard_artifacts.Examples),
    }
    PARAMETERS = dict(
        max_rotation_angle=ExecutionParameter(type=float),
        num_augmented_per_image=ExecutionParameter(type=int),
    )
def feature_list_kind(dtype):
    """Return the tf.train.Feature field name to use for values of `dtype`.

    Uses np.issubdtype, so every integer width maps to int64_list and every
    floating width maps to float_list. Anything else (bytes/object arrays)
    maps to bytes_list.
    """
    if np.issubdtype(dtype, np.integer):
        return 'int64_list'
    if np.issubdtype(dtype, np.floating):
        return 'float_list'
    return 'bytes_list'


def _dict_to_example(instance):
    """Convert a feature dict (name -> numpy array or None) to a tf.train.Example.

    None values become empty Features. Integer arrays become Int64List,
    floating arrays become FloatList, and all other arrays become BytesList.
    """
    feature = {}
    for key, value in instance.items():
        if value is None:
            feature[key] = tf.train.Feature()
            continue
        # `value.dtype == np.integer` is not a subtype test: it misses int32 and
        # fails on newer NumPy, so classify with np.issubdtype instead.
        kind = feature_list_kind(value.dtype)
        if kind == 'int64_list':
            feature[key] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=value.tolist()))
        elif kind == 'float_list':
            feature[key] = tf.train.Feature(
                float_list=tf.train.FloatList(value=value.tolist()))
        else:
            feature[key] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=value.tolist()))
    return tf.train.Example(features=tf.train.Features(feature=feature))
def _augment_image(example, max_rotation_angle, num_augmented_per_image):
    """Return `example` followed by `num_augmented_per_image` rotated copies.

    Args:
      example: decoded feature dict whose "image" entry holds one PNG string.
      max_rotation_angle: each copy is rotated by a random angle drawn uniformly
        from [-max_rotation_angle, +max_rotation_angle], in degrees.
      num_augmented_per_image: number of rotated copies to generate.

    Returns:
      A list of feature dicts: the original example first, then the copies,
      each with its "image" replaced by a re-encoded rotated PNG.
    """
    import math

    image = tf.image.decode_png(example["image"][0])
    # tfa.image.rotate expects radians, but the angle is specified in degrees.
    max_radians = math.radians(max_rotation_angle)
    augmented_examples = []
    for _ in range(num_augmented_per_image):
        # Shallow copy: only the "image" entry is replaced below.
        augmented_example = example.copy()
        angle = tf.random.uniform([], -max_radians, +max_radians)
        augmented_image = rotate(images=image, angles=angle)
        augmented_example["image"] = np.array([
            tf.image.encode_png(augmented_image).numpy()])
        augmented_examples.append(augmented_example)
    return [example] + augmented_examples
class DataAugmentationExecutor(base_executor.BaseExecutor):
    """Executor for the custom TFX image augmentation component.

    Writes every training image plus randomly rotated copies of it to the
    'train' split of the output, and copies the 'eval' split unchanged.
    """

    @staticmethod
    def _all_files_pattern(uri):
        """Return a glob matching every file directly inside directory `uri`."""
        # An explicit '*' makes Beam expand the directory contents even if `uri`
        # has no trailing path separator.
        return os.path.join(uri, '*')

    def Do(self, input_dict, output_dict, exec_properties):
        """Augment the training split and pass the eval split through.

        Args:
          input_dict: Input dict from input key to a list of artifacts, including:
            - examples: Examples with 'train' and 'eval' splits (in this
              notebook, the output of ImportExampleGen).
          output_dict: Output dict from key to a list of artifacts, including:
            - augmented_data: augmented examples ('train' and 'eval' splits).
          exec_properties: A dict of execution properties, including:
            - max_rotation_angle: images will be rotated by a random angle between
              -max_rotation_angle and +max_rotation_angle (in degrees).
            - num_augmented_per_image: number of augmented images per original.

        Returns:
          None
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # Forward only the parameters _augment_image accepts, rather than every
        # execution property, so unrelated properties cannot break the call.
        augment_kwargs = {
            'max_rotation_angle': exec_properties['max_rotation_angle'],
            'num_augmented_per_image': exec_properties['num_augmented_per_image'],
        }

        train_input_uri = artifact_utils.get_split_uri(
            input_dict['examples'], 'train')
        train_output_uri = artifact_utils.get_split_uri(
            output_dict['augmented_data'], 'train')
        decoder = tfdv.TFExampleDecoder()
        with beam.Pipeline() as pipeline:
            _ = (
                pipeline
                | 'ReadTrainData' >> beam.io.ReadFromTFRecord(
                    self._all_files_pattern(train_input_uri))
                | 'ParseExample' >> beam.Map(decoder.decode)
                | 'Augmentation' >> beam.ParDo(_augment_image, **augment_kwargs)
                | 'DictToExample' >> beam.Map(_dict_to_example)
                | 'SerializeExample' >> beam.Map(
                    lambda tf_example: tf_example.SerializeToString())
                | 'WriteAugmentedData' >> beam.io.WriteToTFRecord(
                    os.path.join(train_output_uri, "data_tfrecord"),
                    file_name_suffix='.gz'))

        # Augmentation applies to training only: copy eval records verbatim.
        eval_input_uri = artifact_utils.get_split_uri(
            input_dict['examples'], 'eval')
        eval_output_uri = artifact_utils.get_split_uri(
            output_dict['augmented_data'], 'eval')
        with beam.Pipeline() as pipeline:
            _ = (
                pipeline
                | 'ReadEvalData' >> beam.io.ReadFromTFRecord(
                    self._all_files_pattern(eval_input_uri))
                | 'WriteAugmentedData' >> beam.io.WriteToTFRecord(
                    os.path.join(eval_output_uri, "data_tfrecord"),
                    file_name_suffix='.gz'))
class DataAugmentationComponent(base_component.BaseComponent):
    """Custom TFX image augmentation component.

    Adds randomly rotated copies of every training image to the 'train' split
    and passes the 'eval' split through unchanged. Augmentation therefore
    affects training only, not evaluation or serving. In this notebook it
    consumes the ImportExampleGen output, and its output feeds Transform.
    """

    SPEC_CLASS = DataAugmentationComponentSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(DataAugmentationExecutor)

    def __init__(self,
                 examples,
                 max_rotation_angle=10.,
                 num_augmented_per_image=2,
                 augmented_data=None,
                 instance_name=None):
        """Construct a DataAugmentationComponent.

        Args:
          examples: A Channel of 'Examples' type with 'train' and 'eval' splits,
            e.g. the output of an ExampleGen component.
          max_rotation_angle: images will be rotated by a random angle between
            -max_rotation_angle and +max_rotation_angle (in degrees). Plain
            ints are accepted and converted to float.
          num_augmented_per_image: number of augmented images per original image,
            not including the original image.
          augmented_data: A Channel of 'Examples' type for the output (augmented)
            data. Defaults to a new channel with 'train' and 'eval' artifacts.
          instance_name: Optional unique instance name. Necessary if multiple
            components of this class are declared in the same pipeline.
        """
        # The spec declares max_rotation_angle as a float ExecutionParameter,
        # which may reject a plain int such as 180 when it type-checks the
        # value. bool is excluded because it is a subclass of int.
        if (isinstance(max_rotation_angle, int)
                and not isinstance(max_rotation_angle, bool)):
            max_rotation_angle = float(max_rotation_angle)
        augmented_data = augmented_data or tfx.types.Channel(
            type=standard_artifacts.Examples,
            artifacts=[standard_artifacts.Examples(split="train"),
                       standard_artifacts.Examples(split="eval")])
        spec = DataAugmentationComponentSpec(
            examples=examples,
            max_rotation_angle=max_rotation_angle,
            num_augmented_per_image=num_augmented_per_image,
            augmented_data=augmented_data)
        super().__init__(spec=spec, instance_name=instance_name)
In [0]:
# Augment the ImportExampleGen output: each training image is kept and joined by
# 5 randomly rotated copies (max_rotation_angle is documented in degrees).
data_augmentation = DataAugmentationComponent(
    examples=example_gen.outputs['examples'],
    num_augmented_per_image=5,
    max_rotation_angle=180.,
)
context.run(data_augmentation)
In [0]:
# Read back the augmented training split. Select the artifact by split name
# rather than by list position, and keep only the GZIP TFRecord shards the
# executor wrote (they carry the '.gz' suffix).
train_uri = artifact_utils.get_split_uri(
    data_augmentation.outputs['augmented_data'].get(), 'train')
tfrecord_filenames = sorted(
    os.path.join(train_uri, name)
    for name in os.listdir(train_uri)
    if name.endswith('.gz'))
augmented_train_set = tf.data.TFRecordDataset(tfrecord_filenames,
                                              compression_type="GZIP")
In [0]:
# Show the first num_rows * num_cols augmented training records.
num_rows, num_cols = 10, 5
plt.figure(figsize=(4 * num_cols, 4 * num_rows))
decoder = tfdv.TFExampleDecoder()
records = augmented_train_set.take(num_rows * num_cols)
for index, tfrecord in enumerate(records):
    example = decoder.decode(tfrecord.numpy())
    decoded_image = tf.image.decode_png(example["image"][0])
    plt.subplot(num_rows, num_cols, index + 1)
    plt.imshow(decoded_image)
    plt.title(class_names[example["label"][0]])
    plt.axis('off')
plt.show()
In [0]:
# Set up paths.
# File the tf.Transform preprocessing module is written to and loaded from
# (filename typo "tranform" fixed).
_transform_module_file = 'uc_merced_transform.py'
In [0]:
%%writefile {_transform_module_file}
# tf.Transform preprocessing module loaded by the TFX Transform component.
import tensorflow_transform as tft
import tensorflow as tf
# Name of the raw label feature (not referenced elsewhere in this module).
LABEL_KEY = 'label'
def transformed_name(name):
    """Return the key under which Transform emits the processed `name` feature."""
    return ''.join((name, '_xf'))
def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs.

    Every raw feature is densified (missing values filled in) and emitted under
    its transformed name, e.g. 'image' -> 'image_xf'.

    Args:
      inputs: map from feature keys to raw not-yet-transformed features.

    Returns:
      Map from string feature key to transformed feature operations.
    """
    return {transformed_name(key): _fill_in_missing(tensor)
            for key, tensor in inputs.items()}
def _fill_in_missing(x):
    """Densify a SparseTensor, filling missing values with '' (strings) or 0.

    Args:
      x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
        in the second dimension.

    Returns:
      A rank 1 tensor where missing values of `x` have been filled in.
    """
    if x.dtype == tf.string:
        fill_value = ''
    else:
        fill_value = 0
    # Pin the second dimension to 1 so squeezing yields exactly one value per row.
    padded = tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1])
    dense = tf.sparse.to_dense(padded, fill_value)
    return tf.squeeze(dense, axis=1)
In [0]:
# Run tf.Transform over the augmented examples, using the preprocessing module
# written to _transform_module_file. The resulting graph is reused at serving.
transform = Transform(
    module_file=_transform_module_file,
    schema=infer_schema.outputs['schema'],
    examples=data_augmentation.outputs['augmented_data'],
)
context.run(transform)
In [0]:
# Module name (used to build the "<module>.trainer_fn" reference for Trainer)
# and the file it is written to. The serving directory is set up right before
# the Pusher cell; defining it here as well was a dead duplicate.
_trainer_module = 'uc_merced_trainer'
_trainer_module_file = _trainer_module + '.py'
In [0]:
%%writefile {_trainer_module_file}
# Trainer module: TFX calls trainer_fn() below to build the Estimator and specs.
import tensorflow as tf
keras = tf.keras
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils
# Name of the raw label feature.
LABEL_KEY = 'label'
# Features removed before they reach the model.
DROP_FEATURES = ["filename"]
# Number of output classes; must match the dataset's label count (21 classes
# is assumed for uc_merced — TODO confirm if the dataset changes).
NUM_CLASSES = 21
def transformed_name(name):
    """Return the key under which Transform emits the processed `name` feature."""
    return ''.join((name, '_xf'))
# tf.Transform treats the features described by the schema as "raw" inputs.
def _get_raw_feature_spec(schema):
    """Return the feature spec dict (feature name -> parsing config) for `schema`."""
    spec_result = schema_utils.schema_as_feature_spec(schema)
    return spec_result.feature_spec
def _gzip_reader_fn(filenames):
    """Return a TFRecordDataset that reads `filenames` as GZIP-compressed files."""
    compression = 'GZIP'
    return tf.data.TFRecordDataset(filenames, compression_type=compression)
@tf.function
def decode_and_resize(image):
    """Decode one PNG string into a float32 256x256x3 image (values 0-255).

    channels=3 forces RGB output. With the default (channels=0) a grayscale or
    RGBA PNG would yield 1 or 4 channels, which the later reshape to
    [-1, 256, 256, 3] would silently scramble.
    """
    decoded = tf.io.decode_png(image, channels=3)
    return tf.image.resize(decoded, (256, 256))
@tf.function
def parse_png_images(png_images):
    """Decode a tensor of PNG strings into a batch of images scaled to [0, 1].

    The input may have any shape (the model feeds [batch, 1]). It is flattened,
    each PNG is decoded and resized on the CPU, and the result is reshaped to
    [-1, 256, 256, 3] and divided by 255.
    """
    with tf.device("/cpu:0"):
        images = tf.map_fn(decode_and_resize, tf.reshape(png_images, [-1]),
                           dtype=tf.float32)
        batch = tf.reshape(images, [-1, 256, 256, 3])
        return batch / 255.
def _build_estimator(config, num_filters=None):
    """Build an estimator for classifying uc_merced images.

    The Keras model takes the transformed PNG strings ('image_xf') and decodes
    and rescales them. It then applies one Conv2D(3x3, ReLU) + MaxPool2D pair
    per entry of `num_filters`, flattens, and ends with a softmax Dense layer
    over NUM_CLASSES.

    Args:
      config: tf.estimator.RunConfig defining the runtime environment for the
        estimator (including model_dir).
      num_filters: [int], number of filters per Conv2D layer. Defaults to
        [32, 64, 128, 256] (the configuration trainer_fn builds) when None.

    Returns:
      The estimator that will be used for training and eval.
    """
    # Iterating over the None default used to raise a TypeError.
    if num_filters is None:
        num_filters = [32, 64, 128, 256]
    model = keras.models.Sequential()
    model.add(keras.layers.InputLayer(input_shape=[1], dtype="string",
                                      name="image_xf"))
    model.add(keras.layers.Lambda(parse_png_images))
    for filters in num_filters:
        model.add(keras.layers.Conv2D(filters=filters, kernel_size=3,
                                      activation="relu"))
        model.add(keras.layers.MaxPool2D())
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(NUM_CLASSES, activation="softmax"))
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer="adam", metrics=["accuracy"])
    # custom_objects lets the Lambda layer's function be resolved on conversion.
    return tf.keras.estimator.model_to_estimator(
        keras_model=model,
        config=config,
        custom_objects={"parse_png_images": parse_png_images})
def _example_serving_receiver_fn(tf_transform_output, schema):
    """Build the serving input receiver.

    Parses serialized tf.Examples that carry every raw feature except the label
    and applies the tf.Transform graph. It then feeds the transformed features,
    minus the dropped features and the label, to the model.

    Args:
      tf_transform_output: A TFTransformOutput.
      schema: the schema of the input data.

    Returns:
      A tf.estimator.export.ServingInputReceiver that parses examples and
      applies tf-transform to them.
    """
    raw_feature_spec = _get_raw_feature_spec(schema)
    # Serving requests do not carry a label.
    raw_feature_spec.pop(LABEL_KEY)
    raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = raw_input_fn()
    transformed_features = tf_transform_output.transform_raw_features(
        serving_input_receiver.features)
    # Remove features the model does not consume. pop() uses a default, as in
    # _eval_input_receiver_fn, so a transformed output that is absent (e.g. one
    # derived only from the missing label) does not raise KeyError.
    for feature in DROP_FEATURES + [LABEL_KEY]:
        transformed_features.pop(transformed_name(feature), None)
    return tf.estimator.export.ServingInputReceiver(
        transformed_features, serving_input_receiver.receiver_tensors)
def _eval_input_receiver_fn(tf_transform_output, schema):
    """Build everything needed for tf-model-analysis to run the model.

    Args:
      tf_transform_output: A TFTransformOutput.
      schema: the schema of the input data.

    Returns:
      A tfma.export.EvalInputReceiver with these parts:
        - receiver tensor 'examples': serialized raw (untransformed) tf.Examples;
        - features: raw and transformed features, minus DROP_FEATURES, the
          label (raw and transformed) and the raw 'image';
        - labels: the transformed label.
    """
    # TFMA feeds raw serialized examples, so parse them with the raw spec.
    serialized_tf_example = tf.compat.v1.placeholder(
        dtype=tf.string, shape=[None], name='input_example_tensor')
    raw_features = tf.io.parse_example(serialized_tf_example,
                                       _get_raw_feature_spec(schema))
    # Apply the tf-transform graph computed during preprocessing.
    transformed_features = tf_transform_output.transform_raw_features(
        raw_features)

    # The model is driven by transformed features, while slicing happens on
    # raw features, so expose both.
    features = dict(raw_features)
    features.update(transformed_features)
    for name in DROP_FEATURES + [LABEL_KEY]:
        features.pop(name, None)
        features.pop(transformed_name(name), None)
    features.pop('image')

    return tfma.export.EvalInputReceiver(
        features=features,
        # The receiver key MUST be 'examples'.
        receiver_tensors={'examples': serialized_tf_example},
        labels=transformed_features[transformed_name(LABEL_KEY)])
def _input_fn(filenames, tf_transform_output, batch_size=200):
    """Generates features and labels for training or evaluation.

    Args:
      filenames: [str] list of GZIP-compressed TFRecord files of transformed
        tf.Examples to read.
      tf_transform_output: A TFTransformOutput.
      batch_size: int First dimension size of the Tensors returned by input_fn

    Returns:
      A (features, labels) tuple, where features is a dict of transformed
      feature Tensors (without DROP_FEATURES or the label) and labels is the
      transformed label Tensor.
    """
    transformed_feature_spec = (
        tf_transform_output.transformed_feature_spec().copy())
    dataset = tf.data.experimental.make_batched_features_dataset(
        filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn)
    # TF 2.x Dataset objects have no make_one_shot_iterator() method. The
    # tf.compat.v1 function works under both TF 1.15 and TF 2.x graph mode.
    transformed_features = tf.compat.v1.data.make_one_shot_iterator(
        dataset).get_next()
    for feature in DROP_FEATURES:
        transformed_features.pop(transformed_name(feature))
    labels = transformed_features.pop(transformed_name(LABEL_KEY))
    return transformed_features, labels
# TFX will call this function
def trainer_fn(hparams, schema):
    """Build the estimator and the train/eval specs that TFX's Trainer runs.

    Args:
      hparams: Holds hyperparameters used to train the model as name/value pairs
        (reads transform_output, train_files, eval_files, train_steps,
        eval_steps and serving_model_dir).
      schema: Holds the schema of the training examples.

    Returns:
      A dict with keys 'estimator', 'train_spec', 'eval_spec' and
      'eval_input_receiver_fn'.
    """
    train_batch_size = 40
    eval_batch_size = 40
    num_cnn_layers = 4
    first_cnn_filters = 32

    tf_transform_output = tft.TFTransformOutput(hparams.transform_output)

    def train_input_fn():
        return _input_fn(hparams.train_files, tf_transform_output,
                         batch_size=train_batch_size)

    def eval_input_fn():
        return _input_fn(hparams.eval_files, tf_transform_output,
                         batch_size=eval_batch_size)

    def serving_receiver_fn():
        return _example_serving_receiver_fn(tf_transform_output, schema)

    def receiver_fn():
        # Input receiver for TFMA processing.
        return _eval_input_receiver_fn(tf_transform_output, schema)

    train_spec = tf.estimator.TrainSpec(train_input_fn,
                                        max_steps=hparams.train_steps)
    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn,
        steps=hparams.eval_steps,
        exporters=[tf.estimator.FinalExporter('uc-merced', serving_receiver_fn)],
        name='uc-merced-eval')

    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=999,
        keep_checkpoint_max=1).replace(model_dir=hparams.serving_model_dir)

    # Filter counts double at each layer: 32, 64, 128, 256.
    num_filters = [first_cnn_filters * 2 ** layer
                   for layer in range(num_cnn_layers)]
    estimator = _build_estimator(config=run_config, num_filters=num_filters)

    return {
        'estimator': estimator,
        'train_spec': train_spec,
        'eval_spec': eval_spec,
        'eval_input_receiver_fn': receiver_fn,
    }
In [0]:
# Train the Estimator built by trainer_fn in the trainer module, on the
# transformed examples: 200 training steps and 100 evaluation steps.
trainer = Trainer(
    trainer_fn=_trainer_module + '.trainer_fn',
    transformed_examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=200),
    eval_args=trainer_pb2.EvalArgs(num_steps=100),
)
context.run(trainer)
In [0]:
# Compute TFMA evaluation metrics for the trained model, sliced by 'label_xf'.
# NOTE(review): _eval_input_receiver_fn in the trainer module pops 'label_xf'
# from its features dict, so this slice may yield no per-class slices —
# confirm in eval_result.
label_slice = evaluator_pb2.SingleSlicingSpec(column_for_slicing=['label_xf'])
model_analyzer = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[label_slice]),
)
context.run(model_analyzer)
In [0]:
# Load the TFMA evaluation results written by the Evaluator and display them.
evaluation_uri = model_analyzer.outputs['output'].get()[0].uri
eval_result = tfma.load_eval_result(evaluation_uri)
eval_result
In [0]:
# Validate the candidate model's quality against a baseline and emit a blessing.
model_validator = ModelValidator(
    model=trainer.outputs['model'],
    examples=example_gen.outputs['examples'],
)
context.run(model_validator)
In [0]:
# List the ModelValidator's blessing directory; presumably it holds a
# BLESSED / NOT_BLESSED marker file — confirm from the listing.
blessing_uri = model_validator.outputs['blessing'].get()[0].uri
!ls -l {blessing_uri}
In [0]:
# Set up the serving path: a fresh temporary directory that the Pusher below
# uses as its filesystem push destination.
_serving_model_dir = os.path.join(tempfile.mkdtemp(),
                                  'serving_model/uc_merced_simple')
In [0]:
# Push the trained model to _serving_model_dir, but only if it passed the
# validation steps (ModelValidator blessing).
push_destination = pusher_pb2.PushDestination(
    filesystem=pusher_pb2.PushDestination.Filesystem(
        base_directory=_serving_model_dir))
pusher = Pusher(
    model_export=trainer.outputs['model'],
    model_blessing=model_validator.outputs['blessing'],
    push_destination=push_destination,
)
context.run(pusher)