ML with TensorFlow Extended (TFX) -- Part 3

The purpose of this tutorial is to show how to do end-to-end ML with TFX libraries on Google Cloud Platform. This tutorial covers:

  1. Data analysis and schema generation with TF Data Validation.
  2. Data preprocessing with TF Transform.
  3. Model training with TF Estimator.
  4. Model evaluation with TF Model Analysis.

This notebook has been tested in Jupyter on the Deep Learning VM.

Set up the Cloud environment


In [ ]:
import tensorflow as tf
import tensorflow_data_validation as tfdv
import tensorflow_transform as tft

print('TF version: {}'.format(tf.__version__))
print('TFT version: {}'.format(tft.__version__))
print('TFDV version: {}'.format(tfdv.__version__))

In [ ]:
PROJECT = 'cloud-training-demos'    # Replace with your PROJECT
BUCKET = 'cloud-training-demos-ml'  # Replace with your BUCKET
REGION = 'us-central1'              # Choose an available region for Cloud MLE

import os

os.environ['PROJECT'] = PROJECT
os.environ['BUCKET'] = BUCKET
os.environ['REGION'] = REGION

In [ ]:
%%bash
gcloud config set project $PROJECT
gcloud config set compute/region $REGION

## ensure we predict locally with our current Python environment
gcloud config set ml_engine/local_python `which python`

UCI Adult Dataset: https://archive.ics.uci.edu/ml/datasets/adult

The task is to predict whether income exceeds $50K/yr based on census data. It is also known as the "Census Income" dataset.


In [ ]:
DATA_DIR='gs://cloud-samples-data/ml-engine/census/data'

In [ ]:
import os

TRAIN_DATA_FILE = os.path.join(DATA_DIR, 'adult.data.csv')
EVAL_DATA_FILE = os.path.join(DATA_DIR, 'adult.test.csv')
!gsutil ls -l $TRAIN_DATA_FILE
!gsutil ls -l $EVAL_DATA_FILE

In [ ]:
HEADER = ['age', 'workclass', 'fnlwgt', 'education', 'education_num',
               'marital_status', 'occupation', 'relationship', 'race', 'gender',
               'capital_gain', 'capital_loss', 'hours_per_week',
               'native_country', 'income_bracket']

TARGET_FEATURE_NAME = 'income_bracket'
TARGET_LABELS = [' <=50K', ' >50K']
WEIGHT_COLUMN_NAME = 'fnlwgt_scaled'  # note that the column name was changed in the TF Transform step

RAW_SCHEMA_LOCATION = 'raw_schema.pbtxt'

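Before moving on to training, it helps to glance at a few raw rows. The sketch below is only a quick sanity check; it assumes pandas is available in the notebook environment and that the CSV files have no header row (which is why the HEADER list above is passed as column names).


In [ ]:
import pandas as pd

# Peek at a few raw rows to confirm the column order matches HEADER.
# tf.gfile.Open reads the gs:// path without extra dependencies.
with tf.gfile.Open(TRAIN_DATA_FILE) as f:
  sample_df = pd.read_csv(f, names=HEADER, nrows=5)

sample_df
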
3. Model Training

To train the model, we use the TF Estimator API with a premade DNNClassifier. We perform the following steps:

  1. Load the transform schema
  2. Use the transform schema to parse TFRecords in input_fn
  3. Use the transform schema to create feature columns
  4. Create a premade DNNClassifier
  5. Train the model
  6. Implement the serving_input_fn and apply the transform logic
  7. Export and test the saved model.

3.1 Load transform output


In [ ]:
PREPROC_OUTPUT_DIR = 'gs://{}/census/tfx'.format(BUCKET)  # from 02_transform.ipynb
TRANSFORM_ARTIFACTS_DIR = os.path.join(PREPROC_OUTPUT_DIR,'transform')
TRANSFORMED_DATA_DIR = os.path.join(PREPROC_OUTPUT_DIR,'transformed')
!gsutil ls $TRANSFORM_ARTIFACTS_DIR
!gsutil ls $TRANSFORMED_DATA_DIR

In [ ]:
transform_output = tft.TFTransformOutput(TRANSFORM_ARTIFACTS_DIR)

3.2 TFRecords Input Function


In [ ]:
def make_input_fn(tfrecords_files, batch_size, num_epochs=1, shuffle=False):

  def input_fn():
    dataset = tf.data.experimental.make_batched_features_dataset(
      file_pattern=tfrecords_files,
      batch_size=batch_size,
      features=transform_output.transformed_feature_spec(),
      label_key=TARGET_FEATURE_NAME,
      reader=tf.data.TFRecordDataset,
      num_epochs=num_epochs,
      shuffle=shuffle
    )
    return dataset

  return input_fn

In [ ]:
make_input_fn(TRANSFORMED_DATA_DIR+'/train*.tfrecords', 2, shuffle=False)()

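To sanity-check the input pipeline, the sketch below pulls a single batch through the input_fn and prints the parsed feature names and labels. It assumes TF 1.x graph execution, as used throughout this notebook.


In [ ]:
# Pull one small batch through the input_fn to verify that the TFRecords
# parse against the transformed feature spec (TF 1.x graph mode).
with tf.Graph().as_default():
  dataset = make_input_fn(TRANSFORMED_DATA_DIR+'/train*.tfrecords', batch_size=2)()
  features, target = dataset.make_one_shot_iterator().get_next()
  with tf.Session() as sess:
    features_batch, target_batch = sess.run([features, target])

print(sorted(features_batch.keys()))
print(target_batch)
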
3.3 Create feature columns


In [ ]:
import math

def create_feature_columns():

  feature_columns = []
  transformed_features = transform_output.transformed_metadata.schema._schema_proto.feature

  for feature in transformed_features:

    if feature.name in [TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME]:
      continue

    if hasattr(feature, 'int_domain') and feature.int_domain.is_categorical:
      vocab_size = feature.int_domain.max + 1
      feature_columns.append(
        tf.feature_column.embedding_column(
          tf.feature_column.categorical_column_with_identity(
            feature.name, num_buckets=vocab_size),
            dimension = int(math.sqrt(vocab_size))))
    else:
      feature_columns.append(
        tf.feature_column.numeric_column(feature.name))

  return feature_columns

In [ ]:
create_feature_columns()

3.4 Instantiate the Estimator


In [ ]:
def create_estimator(params, run_config):
    
  feature_columns = create_feature_columns()

  estimator = tf.estimator.DNNClassifier(
    weight_column=WEIGHT_COLUMN_NAME,
    label_vocabulary=TARGET_LABELS,
    feature_columns=feature_columns,
    hidden_units=params.hidden_units,
    dropout=params.dropout,
    config=run_config
  )

  return estimator

3.5 Implement the train-and-evaluate experiment


In [ ]:
from datetime import datetime

def run_experiment(estimator, params, run_config, resume=False):
  
  tf.logging.set_verbosity(tf.logging.INFO)

  if not resume: 
    if tf.gfile.Exists(run_config.model_dir):
      print("Removing previous artifacts...")
      tf.gfile.DeleteRecursively(run_config.model_dir)
  else:
    print("Resuming training...")

  train_spec = tf.estimator.TrainSpec(
      input_fn = make_input_fn(
          TRANSFORMED_DATA_DIR+'/train*.tfrecords',
          batch_size=params.batch_size,
          num_epochs=None,
          shuffle=True
      ),
      max_steps=params.max_steps
  )

  eval_spec = tf.estimator.EvalSpec(
      input_fn = make_input_fn(
          TRANSFORMED_DATA_DIR+'/eval*.tfrecords',
          batch_size=params.batch_size,     
      ),
      start_delay_secs=0,
      throttle_secs=0,
      steps=None
  )
  
  time_start = datetime.utcnow() 
  print("Experiment started at {}".format(time_start.strftime("%H:%M:%S")))
  print(".......................................")
  
  tf.estimator.train_and_evaluate(
    estimator=estimator,
    train_spec=train_spec, 
    eval_spec=eval_spec)

  time_end = datetime.utcnow() 
  print(".......................................")
  print("Experiment finished at {}".format(time_end.strftime("%H:%M:%S")))
  print("")
  
  time_elapsed = time_end - time_start
  print("Experiment elapsed time: {} seconds".format(time_elapsed.total_seconds()))

3.6 Run the experiment


In [ ]:
MODELS_LOCATION = 'models/census'
MODEL_NAME = 'dnn_classifier'
model_dir = os.path.join(MODELS_LOCATION, MODEL_NAME)
os.environ['MODEL_DIR'] = model_dir

params = tf.contrib.training.HParams(
    hidden_units=[128, 64],
    dropout=0.15,
    batch_size=128,
    max_steps=1000
)

run_config = tf.estimator.RunConfig(
    tf_random_seed=19831006,
    save_checkpoints_steps=200, 
    keep_checkpoint_max=3, 
    model_dir=model_dir,
    log_step_count_steps=10
)

In [ ]:
estimator = create_estimator(params, run_config)
run_experiment(estimator, params, run_config)

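train_and_evaluate already logs evaluation metrics at each checkpoint; if you also want the final numbers in the notebook, a separate estimator.evaluate pass over the eval TFRecords works as well:


In [ ]:
# Optional: compute the final evaluation metrics (accuracy, AUC, loss, ...)
# from the latest checkpoint, reusing the same eval TFRecords.
eval_metrics = estimator.evaluate(
  input_fn=make_input_fn(
    TRANSFORMED_DATA_DIR+'/eval*.tfrecords',
    batch_size=params.batch_size
  )
)
print(eval_metrics)
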
3.7 Export the model for serving


In [ ]:
tf.logging.set_verbosity(tf.logging.ERROR)

def make_serving_input_receiver_fn():
  from tensorflow_transform.tf_metadata import schema_utils

  source_raw_schema = tfdv.load_schema_text(RAW_SCHEMA_LOCATION)
  raw_feature_spec = schema_utils.schema_as_feature_spec(source_raw_schema).feature_spec
  raw_feature_spec.pop(TARGET_FEATURE_NAME)
  if WEIGHT_COLUMN_NAME in raw_feature_spec:
    raw_feature_spec.pop(WEIGHT_COLUMN_NAME)


  # Create the interface for the serving function with the raw features
  raw_features = tf.estimator.export.build_parsing_serving_input_receiver_fn(raw_feature_spec)().features

  receiver_tensors = {feature: tf.placeholder(shape=[None], dtype=raw_features[feature].dtype) 
    for feature in raw_features
  }

  receiver_tensors_expanded = {tensor: tf.reshape(receiver_tensors[tensor], (-1, 1)) 
    for tensor in receiver_tensors
  }

  # Apply the transform function 
  transformed_features = transform_output.transform_raw_features(receiver_tensors_expanded)

  return tf.estimator.export.ServingInputReceiver(
    transformed_features, receiver_tensors)

In [ ]:
export_dir = os.path.join(model_dir, 'export')

if tf.gfile.Exists(export_dir):
    tf.gfile.DeleteRecursively(export_dir)
        
estimator.export_savedmodel(
    export_dir_base=export_dir,
    serving_input_receiver_fn=make_serving_input_receiver_fn
)

In [ ]:
%%bash

saved_models_base=${MODEL_DIR}/export/
saved_model_dir=${MODEL_DIR}/export/$(ls ${saved_models_base} | tail -n 1)
echo ${saved_model_dir}
saved_model_cli show --dir=${saved_model_dir} --all

3.8 Try out the saved model


In [ ]:
export_dir = os.path.join(model_dir, 'export')
saved_model_dir = os.path.join(export_dir, tf.gfile.ListDirectory(export_dir)[-1])
print(saved_model_dir)
print()

predictor_fn = tf.contrib.predictor.from_saved_model(
    export_dir = saved_model_dir,
    signature_def_key="predict"
)

input = {
        'age': [34.0],
        'workclass': ['Private'],
        'education': ['Doctorate'],
        'education_num': [10.0],
        'marital_status': ['Married-civ-spouse'],
        'occupation': ['Prof-specialty'],
        'relationship': ['Husband'],
        'race': ['White'],
        'gender': ['Male'],
        'capital_gain': [0.0], 
        'capital_loss': [0.0], 
        'hours_per_week': [40.0],
        'native_country':['Mexico']
}

print(input)
print()
output = predictor_fn(input)
print(output)

3.9 Deploy the model to Cloud ML Engine


In [ ]:
#%%bash
#MODEL_NAME="census"
#MODEL_VERSION="v1"
#MODEL_LOCATION=$(gsutil ls gs://${BUCKET}/census/dnn_classifier/export/exporter | tail -1)
#gcloud ml-engine models create ${MODEL_NAME} --regions $REGION
#gcloud ml-engine versions create ${MODEL_VERSION} --model ${MODEL_NAME} --origin ${MODEL_LOCATION} --runtime-version 1.13

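Note that Cloud ML Engine requires the SavedModel to be on GCS, so you would first copy the export directory to your bucket before creating the version. Once a version is deployed, an online prediction request from Python could look like the sketch below; the model and version names ('census', 'v1') and the google-api-python-client dependency are assumptions based on the commented commands above.


In [ ]:
# Hypothetical online prediction call, assuming the model was deployed
# above as MODEL_NAME="census", MODEL_VERSION="v1".
from googleapiclient import discovery

service = discovery.build('ml', 'v1')
name = 'projects/{}/models/{}/versions/{}'.format(PROJECT, 'census', 'v1')

instance = {
  'age': 34.0, 'workclass': 'Private', 'education': 'Doctorate',
  'education_num': 10.0, 'marital_status': 'Married-civ-spouse',
  'occupation': 'Prof-specialty', 'relationship': 'Husband',
  'race': 'White', 'gender': 'Male', 'capital_gain': 0.0,
  'capital_loss': 0.0, 'hours_per_week': 40.0, 'native_country': 'Mexico'
}

response = service.projects().predict(
  name=name, body={'instances': [instance]}).execute()
print(response)
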
3.10 Export the evaluation saved model


In [ ]:
HEADER_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                   [0], [0], [0], [''], ['']]

def make_eval_input_receiver_fn():
  receiver_tensors = {'examples': tf.placeholder(dtype=tf.string, shape=[None])}
  columns = tf.decode_csv(receiver_tensors['examples'], record_defaults=HEADER_DEFAULTS)
  features = dict(zip(HEADER, columns))
  print(features)

  for feature_name in features:
    if features[feature_name].dtype == tf.int32:
      features[feature_name] = tf.cast(features[feature_name], tf.int64)
    features[feature_name] = tf.reshape(features[feature_name], (-1, 1))

  transformed_features = transform_output.transform_raw_features(features)
  features.update(transformed_features)

  return tfma.export.EvalInputReceiver(
    features=features,
    receiver_tensors=receiver_tensors,
    labels=features[TARGET_FEATURE_NAME]
    )

In [ ]:
import tensorflow_model_analysis as tfma
eval_model_dir = os.path.join(model_dir, "export/evaluate")
if tf.gfile.Exists(eval_model_dir):
    tf.gfile.DeleteRecursively(eval_model_dir)

tfma.export.export_eval_savedmodel(
        estimator=estimator,
        export_dir_base=eval_model_dir,
        eval_input_receiver_fn=make_eval_input_receiver_fn
)

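Model evaluation with TF Model Analysis is the subject of the next part, but as a quick check the exported eval graph can already be run over the raw eval CSV. This is only a sketch; the exact TFMA API names used here (default_eval_shared_model, run_model_analysis, file_format='text') depend on the installed TFMA version.


In [ ]:
# Preview only -- the full TFMA walkthrough is in the next notebook.
eval_saved_model_path = os.path.join(
  eval_model_dir, tf.gfile.ListDirectory(eval_model_dir)[-1])

eval_shared_model = tfma.default_eval_shared_model(
  eval_saved_model_path=eval_saved_model_path)

eval_result = tfma.run_model_analysis(
  eval_shared_model=eval_shared_model,
  data_location=EVAL_DATA_FILE,
  file_format='text'  # the eval receiver above parses raw CSV lines
)

tfma.view.render_slicing_metrics(eval_result)
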
License

Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


This is not an official Google product. The sample code is provided for educational purposes only.