Copyright © 2020 The TensorFlow Authors.

In [ ]:
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Distributed PCA Using TensorFlow Extended (TFX)

Setup

First, install dependencies, import modules, set up paths, and download data.


In [0]:
!pip install -U tensorflow==2.0.0 pyarrow==0.14.1 tfx==0.15.0 grpcio==1.24.3 matplotlib==3.1.2
!pip freeze | grep -e tensorflow -e tfx -e pyarrow

Restart the Runtime

Note: After installing the required Python packages, you'll need to restart the Colab runtime (Menu > Runtime > Restart runtime...).

Import packages

We import necessary packages, including standard TFX component classes.


In [0]:
import os
import absl
import tempfile
import urllib
import numpy as np
from matplotlib import pyplot as plt    
import matplotlib.patches as mpatches
import tensorflow as tf
import tensorflow_transform as tft
import tfx

from tfx.components import CsvExampleGen
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Transform
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.utils.dsl_utils import external_input
from tfx.types import artifact_utils

Set up pipeline paths and logging


In [0]:
# Set up paths.
_tfx_root = tfx.__path__[0]
_iris_root = os.path.join(_tfx_root, 'examples/iris_pca_example')

# Set up logging.
absl.logging.set_verbosity(absl.logging.INFO)

Download example data

We download the sample dataset for use in our TFX pipeline.


In [0]:
# Download the example data.
_data_root = tempfile.mkdtemp(prefix='tfx-data')
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/iris/data/iris.csv'
with open(os.path.join(_data_root, 'data.csv'), 'wb') as f:
    contents = urllib.request.urlopen(DATA_PATH).read()
    f.write(contents)

Create the InteractiveContext

We now create the interactive context.


In [0]:
# Here, we create an InteractiveContext using default parameters. This will
# use a temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, the optional properties
# `pipeline_root` and `metadata_connection_config` may be passed to
# InteractiveContext. Calls to InteractiveContext are no-ops outside of the
# notebook.
context = InteractiveContext()
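
If you want the pipeline artifacts and metadata to persist across sessions, you can pass your own locations instead of the defaults. A minimal sketch, using the optional `pipeline_root` parameter mentioned in the comment above (the path is hypothetical):


In [0]:
# A minimal sketch: pin the pipeline root so artifacts survive beyond the
# temporary directory. `pipeline_root` is an optional InteractiveContext
# parameter; the path below is hypothetical. Uncomment to use instead of
# the default context above.
# context = InteractiveContext(pipeline_root='/tmp/iris_pca_pipeline_root')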

Run TFX components interactively

Next, we construct TFX components and run each one interactively within the interactive session to obtain ExecutionResult objects.

ExampleGen

ExampleGen brings data into the TFX pipeline.


In [0]:
# Use the packaged CSV input data.
examples = external_input(_data_root)

# Brings data into the pipeline or otherwise joins/converts training data.
example_gen = CsvExampleGen(input=examples)
context.run(example_gen)
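
context.run returns an ExecutionResult, and each component's outputs are channels of artifacts. As a quick sanity check, you can list the example artifacts ExampleGen produced; the sketch below assumes this TFX version's per-split artifacts, accessed via the channel's get() method:


In [0]:
# Optional: list the artifacts ExampleGen produced. At this TFX version each
# artifact carries a split name and a URI on disk.
for artifact in example_gen.outputs['examples'].get():
    print(artifact.split, artifact.uri)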

StatisticsGen (using TensorFlow Data Validation)

StatisticsGen computes statistics for visualization and example validation, using the TensorFlow Data Validation library.

Run TFDV statistics computation using the StatisticsGen component


In [0]:
# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'])
context.run(statistics_gen)

Visualize the statistics result


In [0]:
context.show(statistics_gen.outputs['statistics'])
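
The same statistics can also be loaded and visualized with the TensorFlow Data Validation library directly. A sketch, where the stats_tfrecord file name inside the split directory is an assumption about this version's StatisticsGen output layout:


In [0]:
# A minimal sketch using TFDV directly; the 'stats_tfrecord' file name is an
# assumption about how StatisticsGen lays out its output at this version.
import tensorflow_data_validation as tfdv

train_stats_uri = artifact_utils.get_split_uri(
    statistics_gen.outputs['statistics'].get(), 'train')
train_stats = tfdv.load_statistics(os.path.join(train_stats_uri, 'stats_tfrecord'))
tfdv.visualize_statistics(train_stats)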

SchemaGen (using TensorFlow Data Validation)

SchemaGen generates a schema for your data based on the computed statistics. This component also uses the TensorFlow Data Validation library.

Run TFDV schema inference using the SchemaGen component


In [0]:
# Generates schema based on statistics files.
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=False)
context.run(schema_gen)

Visualize the inferred schema


In [0]:
context.show(schema_gen.outputs['schema'])
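
If you need the schema programmatically (for example, to inspect or tweak it before Transform), you can parse the generated protobuf. A sketch, assuming the conventional schema.pbtxt file name for SchemaGen output:


In [0]:
# A minimal sketch for loading the inferred schema as a proto; 'schema.pbtxt'
# is the conventional SchemaGen output file name.
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx.utils import io_utils

schema_uri = artifact_utils.get_single_uri(schema_gen.outputs['schema'].get())
schema = io_utils.parse_pbtxt_file(
    os.path.join(schema_uri, 'schema.pbtxt'), schema_pb2.Schema())
print([f.name for f in schema.feature])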

Transform

Transform performs data transformations and feature engineering that are kept in sync between training and serving.


In [0]:
_iris_pca_transform_module_file = 'iris_pca_transform.py'

In [0]:
%%writefile {_iris_pca_transform_module_file}

import tensorflow as tf
import tensorflow_transform as tft

LABEL_KEY = 'variety'

def _fill_in_missing(x):
    """Replaces missing values in a `SparseTensor` and converts it to dense.

    Fills in missing values of `x` with '' or 0 (depending on dtype) and
    converts it to a dense tensor.

    Args:
      x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
        in the second dimension.

    Returns:
      A rank-2 dense tensor of shape [None, 1] with missing values of `x`
      filled in.
    """
    default_value = '' if x.dtype == tf.string else 0
    return tf.sparse.to_dense(
        tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
        default_value)

def preprocessing_fn(inputs):
    features = []
    outputs = {
        LABEL_KEY: _fill_in_missing(inputs[LABEL_KEY])
    }
    
    for feature_name, feature_tensor in inputs.items():
        if feature_name != LABEL_KEY:
            features.append(tft.scale_to_z_score( # standard scaler pre-req for PCA
                _fill_in_missing(feature_tensor)         # filling in missing values
            ))

    # concatenate the standardized features into a single feature matrix for PCA
    feature_matrix = tf.concat(features, axis=1)
    
    # compute the matrix of orthonormal principal-component vectors
    orthonormal_vectors = tft.pca(feature_matrix, output_dim=2, dtype=tf.float32)

    # project the feature matrix onto the principal components
    pca_examples = tf.linalg.matmul(feature_matrix, orthonormal_vectors)
    
    # unstack and add to output dict
    pca_examples = tf.unstack(pca_examples, axis=1)
    outputs['Principal Component 1'] = pca_examples[0]
    outputs['Principal Component 2'] = pca_examples[1]
      
    return outputs
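
To see what tft.pca is doing here, the following self-contained NumPy snippet mirrors the same math outside the pipeline: PCA projects the z-scored feature matrix onto the top eigenvectors of its covariance matrix. The random matrix is just a stand-in, not the iris data:


In [0]:
# A minimal NumPy sanity check (not part of the pipeline) of the math above.
import numpy as np

X = np.random.randn(150, 4)                       # stand-in for the 4 iris features
Xz = (X - X.mean(axis=0)) / X.std(axis=0)         # what tft.scale_to_z_score does
cov = np.cov(Xz, rowvar=False)                    # covariance of standardized data
eigvals, eigvecs = np.linalg.eigh(cov)
components = eigvecs[:, np.argsort(eigvals)[::-1][:2]]  # top-2 orthonormal directions
projected = Xz @ components                       # analogous to tf.linalg.matmul above
print(projected.shape)                            # (150, 2)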

Run the Transform component


In [0]:
# Performs transformations and feature engineering in training and serving.
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_iris_pca_transform_module_file))
context.run(transform)

Get the transformed examples for both the train and eval splits


In [0]:
# Build a tf.data.Dataset over the transformed TFRecord files.
transformed_examples_patterns = [
    os.path.join(artifact_utils.get_split_uri(
        transform.outputs['transformed_examples'].get(), 'train'), '*'),
    os.path.join(artifact_utils.get_split_uri(
        transform.outputs['transformed_examples'].get(), 'eval'), '*'),
]
transformed_examples_files = [f for f in tf.data.Dataset.list_files(transformed_examples_patterns)]
transformed_dataset = tf.data.TFRecordDataset(transformed_examples_files,
                                              compression_type="GZIP")

# Use tft.TFTransformOutput to recover the transformed feature spec.
transform_output_uri = artifact_utils.get_single_uri(
    transform.outputs['transform_output'].get())
tf_transform_output = tft.TFTransformOutput(transform_output_uri)
feature_spec = tf_transform_output.transformed_feature_spec()

# Parse each raw TFRecord into a dict of tensors using the feature spec above.
def _parse_function(example_proto):
    return tf.io.parse_single_example(example_proto, feature_spec)

transformed_dataset = transformed_dataset.map(_parse_function)
print("Feature Spec:")
feature_spec
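
As a quick check, you can pull a single parsed example out of the dataset before converting everything to arrays:


In [0]:
# Peek at one parsed, transformed example.
for parsed in transformed_dataset.take(1):
    print({k: v.numpy() for k, v in parsed.items()})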

Extract the transformed data into NumPy arrays


In [0]:
pc1, pc2, label_data = [], [], []
for f_dict in transformed_dataset:
    pc1.append(f_dict['Principal Component 1'].numpy())
    pc2.append(f_dict['Principal Component 2'].numpy())
    label_data.append(f_dict['variety'].numpy()[0])
pc1, pc2, label_data = np.array(pc1), np.array(pc2), np.array(label_data)
pc1.shape[0], pc2.shape[0], label_data.shape[0]  # Should all be equal

Plot PCA data


In [0]:
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(1, 1, 1)
ax.set_title('Iris Dataset PCA', fontsize=20)

for i in range(label_data.shape[0]):
    color = 'red' if label_data[i] == 0 else 'green' if label_data[i] == 1 else 'purple'
    ax.scatter(x=pc1[i], y=pc2[i], c=color, alpha=0.7)
    
ax.set_xlabel("Principal Component 1")
ax.set_ylabel("Principal Component 2")
plt.legend(handles=[mpatches.Patch(color='red', label='Iris-setosa'), 
                    mpatches.Patch(color='green', label='Iris-versicolor'), 
                    mpatches.Patch(color='purple', label='Iris-virginica')])

plt.show()
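
Since the variety label is encoded as the integers 0, 1, and 2, the per-point loop above can also be written as a single vectorized scatter call:


In [0]:
# Equivalent, vectorized alternative to the per-point loop above, assuming the
# labels are the integers 0, 1, and 2.
colors = np.array(['red', 'green', 'purple'])[label_data.astype(int)]
plt.figure(figsize=(10, 6))
plt.scatter(pc1, pc2, c=colors, alpha=0.7)
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.show()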