Easy distributed training with TensorFlow using train_and_evaluate()

Introduction

TensorFlow release 1.4 introduced the function tf.estimator.train_and_evaluate, which simplifies training, evaluation, and exporting of Estimator models. It enables distributed execution for training and evaluation while also supporting local execution, and provides consistent behavior across both local/non-distributed and distributed configurations.

This means that, using tf.estimator.train_and_evaluate, you can run the same code both locally and in the cloud, on different devices and with different cluster configurations, without making any code changes. A train-and-evaluate loop is automatically supported. When you're done training (or even at intermediate stages), the trained model is automatically exported in a form suitable for serving (e.g. for Cloud ML Engine online prediction, or TensorFlow Serving).

In this example, we'll walk through how to use tf.estimator.train_and_evaluate with an Estimator model, and then show how easy it is to do distributed training of the model on Cloud ML Engine (CMLE), and to move between different cluster configurations with just a config tweak.

The example also includes the use of Datasets to manage our input data. This API is part of TensorFlow 1.4, and is an easier and more performant way to create input pipelines to TensorFlow models.

For our example, we'll use the Census Income Data Set hosted by the UC Irvine Machine Learning Repository. We have hosted the data on Google Cloud Storage in a slightly cleaned form. We'll use this dataset to predict income category based on various information about a person.

The example in this notebook is a slightly modified version of this example.

Prerequisites

This example requires you to have TensorFlow 1.4 or higher installed, and Python 2.7 or 3. We strongly recommend that you install TensorFlow into a virtual environment, as described in the installation instructions.

For some of the later parts of the example, you'll also need to have Google Cloud SDK (gcloud) installed.

If you're running this notebook from colab, these prerequisites are already met.

In a later section of this example, you'll optionally need to create a GCP project (to use Cloud ML Engine). We'll point you to that info when we get there.

First step: create an Estimator

In this section, we'll do some setup and then create an Estimator model using a prebuilt Estimator subclass, DNNLinearCombinedClassifier. (More on this Estimator below).

We're using the Estimator class because it gives us built-in support for distributed training and evaluation (along with other nice features). You should nearly always use an Estimator to create your TensorFlow models. You can build a Custom Estimator if none of the prebuilt ("canned") Estimators suit your purpose.

First, copy the training and test data to a local directory and set some vars to point to the files.

If you're working in this github repo, copy the files from the parent directory as follows:


In [ ]:
%%bash
mkdir -p census_data
cp ../adult.data.csv census_data/adult.data.csv
cp ../adult.test.csv census_data/adult.test.csv

Alternately, if you're not working out of a github repo, you can download them as follows. You only need to grab the files one way or the other.


In [ ]:
%%bash
mkdir -p census_data
curl https://storage.googleapis.com/cloudml-public/census/data/adult.data.csv -o ./census_data/adult.data.csv
curl https://storage.googleapis.com/cloudml-public/census/data/adult.test.csv -o ./census_data/adult.test.csv

List the contents of the data directory as a check:


In [ ]:
%%bash
ls -l census_data
head census_data/adult.data.csv

In [ ]:
TRAIN_FILES = ['census_data/adult.data.csv']
EVAL_FILES  = ['census_data/adult.test.csv']

In [ ]:
%env TRAIN_FILE=census_data/adult.data.csv
%env EVAL_FILE=census_data/adult.test.csv

Do some imports and check your version of TensorFlow. It should be >=1.4.


In [ ]:
import argparse
import multiprocessing
import os
import time

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.utils import (
    saved_model_export_utils)
print(tf.__version__)

Now we'll begin defining our estimator. First we'll define the format of the input data.
income_bracket is our LABEL_COLUMN, meaning that this is the value we'll predict.


In [ ]:
CSV_COLUMNS = ['age', 'workclass', 'fnlwgt', 'education', 'education_num',
               'marital_status', 'occupation', 'relationship', 'race', 'gender',
               'capital_gain', 'capital_loss', 'hours_per_week',
               'native_country', 'income_bracket']
CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                       [0], [0], [0], [''], ['']]
LABEL_COLUMN = 'income_bracket'
LABELS = [' <=50K', ' >50K']

# Define the initial ingestion of each feature used by your model.
# Additionally, provide metadata about the feature.
INPUT_COLUMNS = [
    # Categorical base columns

    # For categorical columns with known values we can provide lists
    # of values ahead of time.
    tf.feature_column.categorical_column_with_vocabulary_list(
        'gender', [' Female', ' Male']),

    tf.feature_column.categorical_column_with_vocabulary_list(
        'race',
        [' Amer-Indian-Eskimo', ' Asian-Pac-Islander',
         ' Black', ' Other', ' White']
    ),
    tf.feature_column.categorical_column_with_vocabulary_list(
        'education',
        [' Bachelors', ' HS-grad', ' 11th', ' Masters', ' 9th',
         ' Some-college', ' Assoc-acdm', ' Assoc-voc', ' 7th-8th',
         ' Doctorate', ' Prof-school', ' 5th-6th', ' 10th',
         ' 1st-4th', ' Preschool', ' 12th']),
    tf.feature_column.categorical_column_with_vocabulary_list(
        'marital_status',
        [' Married-civ-spouse', ' Divorced', ' Married-spouse-absent',
         ' Never-married', ' Separated', ' Married-AF-spouse', ' Widowed']),
    tf.feature_column.categorical_column_with_vocabulary_list(
        'relationship',
        [' Husband', ' Not-in-family', ' Wife', ' Own-child', ' Unmarried',
         ' Other-relative']),
    tf.feature_column.categorical_column_with_vocabulary_list(
        'workclass',
        [' Self-emp-not-inc', ' Private', ' State-gov',
         ' Federal-gov', ' Local-gov', ' ?', ' Self-emp-inc',
         ' Without-pay', ' Never-worked']
    ),

    # For columns with a large number of values, or unknown values
    # We can use a hash function to convert to categories.
    tf.feature_column.categorical_column_with_hash_bucket(
        'occupation', hash_bucket_size=100, dtype=tf.string),
    tf.feature_column.categorical_column_with_hash_bucket(
        'native_country', hash_bucket_size=100, dtype=tf.string),

    # Continuous base columns.
    tf.feature_column.numeric_column('age'),
    tf.feature_column.numeric_column('education_num'),
    tf.feature_column.numeric_column('capital_gain'),
    tf.feature_column.numeric_column('capital_loss'),
    tf.feature_column.numeric_column('hours_per_week'),
]

# Now we'll define the unused columns-- those we won't use for this example.
# In this case, there's just one: 'fnlwgt'.
UNUSED_COLUMNS = set(CSV_COLUMNS) - {col.name for col in INPUT_COLUMNS} - {LABEL_COLUMN}
print('unused columns: %s' % UNUSED_COLUMNS)
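
To make the hash-bucket idea concrete, here is a minimal pure-Python sketch of how a string value can be mapped to one of hash_bucket_size categories. It uses hashlib.md5 as a stand-in, which is an assumption for illustration only: TensorFlow uses its own hash function, so the bucket ids will not match those produced by categorical_column_with_hash_bucket. The helper name hash_bucket is hypothetical.

```python
import hashlib


def hash_bucket(value, hash_bucket_size=100):
    """Map a string to a bucket id in [0, hash_bucket_size).

    Illustrative stand-in only: TensorFlow uses a different hash function.
    """
    digest = hashlib.md5(value.encode('utf-8')).hexdigest()
    return int(digest, 16) % hash_bucket_size


for occupation in [' Adm-clerical', ' Exec-managerial', ' Some-unseen-job']:
    # Unseen values still get a bucket, which is why no vocabulary is needed.
    print('%r -> bucket %d' % (occupation, hash_bucket(occupation)))
```

Distinct values can land in the same bucket. A larger hash_bucket_size makes such collisions less likely, at the cost of more model parameters.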

Now we'll define a function that builds our Estimator.

We will use the DNNLinearCombinedClassifier class to create our Estimator.

This is a so-called "wide and deep" model. Wide and deep models use deep neural nets to learn high-level abstractions about complex features or interactions between such features. These models then combine the outputs from the DNN with a linear model trained on simpler features. This provides a balance between power and speed that is effective on many structured data problems.

You can read more about this model and its use here. You can learn more about using feature columns here.


In [ ]:
def build_estimator(config, embedding_size=8, hidden_units=None):
  """Build a wide and deep model for predicting income category.
  """
  (gender, race, education, marital_status, relationship,
   workclass, occupation, native_country, age,
   education_num, capital_gain, capital_loss, hours_per_week) = INPUT_COLUMNS

  # Continuous columns can be converted to categorical via bucketization
  age_buckets = tf.feature_column.bucketized_column(
      age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

  # Wide columns and deep columns.
  wide_columns = [

      # TODO phase 1: uncomment to add use of feature crosses and buckets
      
      # Interactions between different categorical features can also
      # be added as new virtual features.
#      tf.feature_column.crossed_column(
#          ['education', 'occupation'], hash_bucket_size=int(1e4)),
#      tf.feature_column.crossed_column(
#          [age_buckets, 'race', 'occupation'], hash_bucket_size=int(1e6)),

      # TODO phase 2: add a new feature cross for native_country x occupation
      
      gender,
      native_country,
      education,
      occupation,
      workclass,
      marital_status,
      relationship,
#      age_buckets,
  ]

  deep_columns = [

      # TODO phase 1: uncomment to add more numeric columns, plus
      # use of indicator and embedding feature columns

      # Use indicator columns for low dimensional vocabularies
#      tf.feature_column.indicator_column(workclass),
#      tf.feature_column.indicator_column(education),
#      tf.feature_column.indicator_column(marital_status),
#      tf.feature_column.indicator_column(gender),
#      tf.feature_column.indicator_column(relationship),
#      tf.feature_column.indicator_column(race),

      # Use embedding columns for high dimensional vocabularies
#      tf.feature_column.embedding_column(
#         native_country, dimension=embedding_size),

      # TODO phase 2: add an embedding column for occupation

      age,
#      education_num,
#      capital_gain,
#      capital_loss,
#      hours_per_week,
  ]

  return tf.estimator.DNNLinearCombinedClassifier(
      config=config,
      linear_feature_columns=wide_columns,
      dnn_feature_columns=deep_columns,
      dnn_hidden_units=hidden_units or [100, 70, 50, 25]
  )
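
As a rough illustration of what the age_buckets column does, the sketch below assigns an age to a bucket index using the same boundaries. It is plain Python with bisect rather than TensorFlow code, and the helper name age_bucket is hypothetical. It assumes the convention that each boundary is the inclusive lower edge of the next bucket.

```python
import bisect

AGE_BOUNDARIES = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]


def age_bucket(age, boundaries=AGE_BOUNDARIES):
    """Return the bucket index for an age; 10 boundaries give 11 buckets."""
    return bisect.bisect_right(boundaries, age)


for age in [17, 18, 39, 65, 90]:
    print('age %d -> bucket %d' % (age, age_bucket(age)))
```

Bucketizing lets the linear ("wide") part of the model learn a separate weight per age range instead of a single weight for age as a number.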

Now, we'll create an estimator object using the function we defined, and our config values.


In [ ]:
output_dir = "census_%s" % (int(time.time()))
print(output_dir)

In [ ]:
%env OUTPUT_DIR=$output_dir

In [ ]:
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(model_dir=output_dir)

FIRST_LAYER_SIZE = 100  # Number of nodes in the first layer of the DNN
NUM_LAYERS = 4  # Number of layers in the DNN
SCALE_FACTOR = 0.7  # How quickly should the size of the layers in the DNN decay
EMBEDDING_SIZE = 8  # Number of embedding dimensions for categorical columns

estimator = build_estimator(
    embedding_size=EMBEDDING_SIZE,
    # Construct layers sizes with exponential decay
    hidden_units=[
        max(2, int(FIRST_LAYER_SIZE *
                   SCALE_FACTOR**i))
        for i in range(NUM_LAYERS)
    ],
    config=run_config
)
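
The list comprehension above can be run on its own to see which layer sizes it produces. The sketch below repeats the same formula outside TensorFlow; the helper name layer_sizes is hypothetical. With the values used here the sizes come out at roughly 100, 70, 49 and 34, although int() truncates, so floating-point rounding can leave a size one lower than the exact product.

```python
def layer_sizes(first_layer_size=100, num_layers=4, scale_factor=0.7):
    """Layer sizes that shrink geometrically, never dropping below 2 units."""
    return [max(2, int(first_layer_size * scale_factor**i))
            for i in range(num_layers)]


print(layer_sizes())
# A steeper decay hits the floor of 2 units after a few layers.
print(layer_sizes(first_layer_size=16, num_layers=6, scale_factor=0.25))
```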

Define input functions (using Datasets)

To train and evaluate the estimator model, we'll need to tell it how to get its training and eval data. We'll define a function (input_fn) that knows how to generate features and labels for training or evaluation, then use that definition to create the actual train and eval input functions.

We'll use Datasets in the input_fn to access our data. This API is part of TensorFlow 1.4, and is a new way to create input pipelines to TensorFlow models. The Dataset API is much more performant than using feed_dict or the queue-based pipelines, and it's cleaner and easier to use.

(In this simple example, our datasets are too small for the use of the Datasets API to make a large difference, but with larger datasets it becomes more important).

We'll first define a couple of helper functions. parse_label_column is used to convert the label strings (in our case, ' <=50K' and ' >50K') into integer class indices (0 and 1).


In [ ]:
def parse_label_column(label_string_tensor):
  """Parses a string tensor into the label tensor
  """
  # Build a Hash Table inside the graph
  table = tf.contrib.lookup.index_table_from_tensor(tf.constant(LABELS))

  # Use the hash table to convert string labels to integer class indices
  return table.lookup(label_string_tensor)
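
For intuition, the lookup performed by the table is sketched below in plain Python. Each label string maps to its position in LABELS, giving an integer class index rather than a one-hot vector. The sketch assumes that strings missing from the table map to -1, which is the default value of index_table_from_tensor; the helper name label_index is hypothetical.

```python
LABELS = [' <=50K', ' >50K']

# Position in LABELS is the class index.
LABEL_TO_INDEX = {label: index for index, label in enumerate(LABELS)}


def label_index(label_string, default_value=-1):
    """Return the class index for a label string, or default_value if unknown."""
    return LABEL_TO_INDEX.get(label_string, default_value)


print([label_index(s) for s in [' <=50K', ' >50K', '>50K']])
```

The leading space matters: the raw CSV fields keep the space that follows each comma, so '>50K' without it is treated as an unknown label.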

In [ ]:
def parse_csv(rows_string_tensor):
  """Takes the string input tensor and returns a dict of rank-2 tensors."""

  # Takes a rank-1 tensor and converts it into rank-2 tensor
  # Example if the data is ['csv,line,1', 'csv,line,2', ..] to
  # [['csv,line,1'], ['csv,line,2']] which after parsing will result in a
  # tuple of tensors: [['csv'], ['csv']], [['line'], ['line']], [[1], [2]]
  row_columns = tf.expand_dims(rows_string_tensor, -1)
  columns = tf.decode_csv(row_columns, record_defaults=CSV_COLUMN_DEFAULTS)
  features = dict(zip(CSV_COLUMNS, columns))

  # Remove unused columns
  for col in UNUSED_COLUMNS:
    features.pop(col)
  return features
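
The effect of the record defaults can be seen in a small plain-Python sketch. It mirrors the idea behind parse_csv for a single line: each field takes the type of its default, an empty field falls back to that default, and unused columns are dropped. The shortened four-column schema and the names SCHEMA, UNUSED and parse_line are hypothetical and only serve the illustration.

```python
import csv

# A shortened, hypothetical schema: (column name, default value).
SCHEMA = [('age', 0), ('workclass', ''), ('fnlwgt', 0), ('income_bracket', '')]
UNUSED = {'fnlwgt'}


def parse_line(line):
    """Parse one CSV line into a feature dict, typed by each column's default."""
    fields = next(csv.reader([line]))
    features = {}
    for (name, default), raw in zip(SCHEMA, fields):
        if name in UNUSED:
            continue
        # An empty field falls back to the default; otherwise cast to its type.
        features[name] = type(default)(raw) if raw != '' else default
    return features


print(parse_line('39, State-gov,77516, <=50K'))
print(parse_line(', Private,1, >50K'))
```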

And now define the input function:


In [ ]:
# This function returns a (features, indices) tuple, where features is a dictionary of
# Tensors, and indices is a single Tensor of label indices.
def input_fn(filenames,
             num_epochs=None,
             shuffle=True,
             skip_header_lines=0,
             batch_size=200):
  """Generates features and labels for training or evaluation.
  """

  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
  if shuffle:
    # Process the files in a random order.
    filename_dataset = filename_dataset.shuffle(len(filenames))
    
  # For each filename, parse it into one element per line, and skip the header
  # if necessary.
  dataset = filename_dataset.flat_map(
      lambda filename: tf.data.TextLineDataset(filename).skip(skip_header_lines))
  
  dataset = dataset.map(parse_csv)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=batch_size * 10)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  iterator = dataset.make_one_shot_iterator()
  features = iterator.get_next()
  return features, parse_label_column(features.pop(LABEL_COLUMN))
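
The order of repeat and batch in input_fn matters: because the dataset is repeated before it is batched, a batch can span the boundary between two epochs, and only the very last batch can be short. The plain-Python sketch below mimics that ordering with itertools. It illustrates the semantics only and is not the tf.data implementation; the name batches is hypothetical.

```python
import itertools


def batches(records, num_epochs=None, batch_size=3):
    """Yield lists of up to batch_size records, repeating first, then batching."""
    if num_epochs is None:
        epochs = itertools.repeat(records)  # repeat indefinitely
    else:
        epochs = itertools.repeat(records, num_epochs)
    stream = itertools.chain.from_iterable(epochs)
    while True:
        batch = list(itertools.islice(stream, batch_size))
        if not batch:
            return
        yield batch


print(list(batches([1, 2, 3, 4], num_epochs=2)))
# With num_epochs=None the stream never ends, so take only a few batches.
print(list(itertools.islice(batches([1, 2, 3, 4]), 3)))
```

With num_epochs=None the input never runs out, so something else has to decide when to stop, such as a fixed number of steps.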

Then, we'll use input_fn to define both the train_input and eval_input functions. We just need to pass input_fn the different source files to use for training versus evaluation. As we'll see below, these will be used to define a TrainSpec and EvalSpec used by train_and_evaluate.


In [ ]:
train_input = lambda: input_fn(
    TRAIN_FILES,
    batch_size=40
)

# Don't shuffle evaluation data
eval_input = lambda: input_fn(
    EVAL_FILES,
    batch_size=40,
    shuffle=False
)

Define training and eval specs

Now we're nearly set. We just need to define the TrainSpec and EvalSpec used by tf.estimator.train_and_evaluate. These specify not only the input functions, but also how to export our trained model.

First, we'll define the TrainSpec, which takes as an arg train_input:


In [ ]:
train_spec = tf.estimator.TrainSpec(train_input,
                                  max_steps=7500
                                  )

For our EvalSpec, we'll instantiate it with something additional -- a list of exporters, that specify how to export a trained model so that it can be used for serving.

To specify our exporter, we first define a serving input function. A serving input function should produce a ServingInputReceiver.

A ServingInputReceiver is instantiated with two arguments: features and receiver_tensors. The features represent the inputs to our Estimator when it is being served for prediction. The receiver_tensors represent the inputs to the server. These will not necessarily always be the same; in some cases we may want to make some edits. Here's one example of that, where the inputs to the server (csv-formatted rows) include a field to be removed.

However, in our case, the inputs to the server are the same as the features input to the model.


In [ ]:
def json_serving_input_fn():
  """Build the serving inputs."""
  inputs = {}
  for feat in INPUT_COLUMNS:
    inputs[feat.name] = tf.placeholder(shape=[None], dtype=feat.dtype)

  return tf.estimator.export.ServingInputReceiver(inputs, inputs)
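
Because json_serving_input_fn creates one placeholder per entry in INPUT_COLUMNS, a prediction request is expected to supply a value for each of those column names. The sketch below builds one such instance as a line of JSON and checks its keys. The field values are made up for illustration and are not taken from the test.json file used later; FEATURE_NAMES and make_instance are hypothetical names.

```python
import json

FEATURE_NAMES = ['gender', 'race', 'education', 'marital_status',
                 'relationship', 'workclass', 'occupation', 'native_country',
                 'age', 'education_num', 'capital_gain', 'capital_loss',
                 'hours_per_week']


def make_instance(**values):
    """Serialize one prediction instance, checking that every feature is present."""
    missing = set(FEATURE_NAMES) - set(values)
    extra = set(values) - set(FEATURE_NAMES)
    if missing or extra:
        raise ValueError('missing=%s extra=%s' % (sorted(missing), sorted(extra)))
    return json.dumps(values, sort_keys=True)


# Made-up example values; string fields keep the leading space used in the CSV.
line = make_instance(
    age=25, workclass=' Private', education=' 11th', education_num=7,
    marital_status=' Never-married', occupation=' Machine-op-inspct',
    relationship=' Own-child', race=' Black', gender=' Male',
    capital_gain=0, capital_loss=0, hours_per_week=40,
    native_country=' United-States')
print(line)
```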

Then, we define an Exporter in terms of that serving input function, and pass the EvalSpec constructor a list of exporters. (We're just using one exporter here, but if you define multiple exporters, training will result in multiple saved models).


In [ ]:
exporter = tf.estimator.FinalExporter('census',
      json_serving_input_fn)
eval_spec = tf.estimator.EvalSpec(eval_input,
                                steps=100,
                                exporters=[exporter],
                                name='census-eval'
                                )

Train your model, using train_and_evaluate

Now, we have defined everything we need to train and evaluate our model, and export the trained model for serving, via a call to train_and_evaluate:


In [ ]:
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

You've just trained your model and exported the result in a format that makes it easy to use it for serving! The training behavior will be consistent across both local/non-distributed and distributed configurations, thanks to train_and_evaluate.

Look at the signature of your exported model

TensorFlow ships with a CLI that allows you to inspect the signature of exported binary files. This lets us confirm the expected inputs, and shows the outputs we'll get when we run a prediction.

To do this, first locate your model:


In [ ]:
# List the directory that contains the model.
!ls -R $output_dir/export/census

Now, view the model signature.


In [ ]:
%%bash
model_dir=$(ls ${OUTPUT_DIR}/export/census)
saved_model_cli show --dir ${OUTPUT_DIR}/export/census/${model_dir} --tag serve --signature_def predict
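
The cell above assumes that export/census contains a single timestamped subdirectory. If training has been run more than once into the same output directory there may be several. The sketch below, in plain Python with a hypothetical helper latest_export, picks the most recent one by treating the directory names as integer timestamps. That naming is an assumption based on the <timestamp> directory names that appear later in this notebook.

```python
def latest_export(dir_names):
    """Return the numerically largest timestamp-like name, or None if there is none."""
    timestamps = [name for name in dir_names if name.isdigit()]
    return max(timestamps, key=int) if timestamps else None


# Hypothetical directory listing; 'temp-123' stands for an unrelated entry.
print(latest_export(['1512345678', '1512349999', 'temp-123']))
```

In the notebook, the list of names could come from os.listdir on the export/census directory under output_dir.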

Use the Google Cloud SDK to make predictions on your trained model.

Next, we'll use the Google Cloud SDK (gcloud) as an easy way to make predictions using our model. This section requires that you have the Google Cloud SDK (gcloud) installed.

If you're running this notebook on colab, gcloud is already installed for you.

We can use gcloud to easily make predictions using our local learned model, using
gcloud ml-engine local predict.
Note the local modifier. This command is a good way to check locally that your exported model is behaving as expected.
(Below we'll look at how to use scalable Cloud ML Engine Online Prediction instead.)

We'll use the example input in test.json. As we saw above when we built our model, we'll be predicting 'income bracket' based on the features encoded in the test.json instance.

If you're running this notebook in colab, first download the test.json file:


In [ ]:
%%bash
wget https://raw.githubusercontent.com/amygdala/tensorflow-workshop/master/workshop_sections/wide_n_deep/tf_train_and_evaluate/test.json

Then, use gcloud to make a local prediction.


In [ ]:
%%bash
cat test.json
model_dir=$(ls ${OUTPUT_DIR}/export/census)
gcloud ml-engine local predict \
  --model-dir=${OUTPUT_DIR}/export/census/${model_dir} \
  --json-instances=test.json

You can see how the input fields in test.json correspond to the inputs listed by the saved_model_cli command above, and how the prediction outputs correspond to the outputs listed by saved_model_cli. In this model, Class 0 indicates income <= 50k and Class 1 indicates income >50k.

Note: if you are using Python 3 and see an error, do the following.
First, make note of where your gcloud sdk is installed:


In [ ]:
%%bash
which gcloud

The result should look something like the following, with your path substituted:
<path_to_google-cloud-sdk>/bin/gcloud
Then, edit the following command to use the path to google-cloud-sdk. After you've run it, rerun the prediction command above.


In [ ]:
%%bash
rm <path_to_google-cloud-sdk>/lib/googlecloudsdk/command_lib/ml_engine/*.pyc

Using Cloud ML Engine for easy distributed training and scalable online prediction

In the previous section, we looked at how to use tf.estimator.train_and_evaluate to train and export a model, and then make predictions using the trained model.

In this section, you'll see how easy it is to use the same code — without any changes — to do distributed training on Cloud ML Engine (CMLE), thanks to the Estimator class and train_and_evaluate. Then we'll use CMLE Online Prediction to scalably serve the trained model.

This section requires that you have set up a GCP project and enabled the use of CMLE.

To run training on CMLE, we can use gcloud. We'll need to package our code so that it can be deployed, and specify the Python file to run to start the training (--module-name).

If you take a look in the trainer subdirectory of this directory, you'll see that it contains essentially the same code that's in this notebook, just packaged for deployment. trainer.task is the entry point, and when that file is run, it calls tf.estimator.train_and_evaluate.
(You can read more about how to package your code here).

We'll test training via gcloud locally first, to make sure that we have everything packaged up correctly.

Test training locally via gcloud

If you're running this notebook in colab, first download the trainer module:


In [ ]:
%%bash
mkdir -p trainer
cd trainer
wget https://raw.githubusercontent.com/amygdala/tensorflow-workshop/master/workshop_sections/wide_n_deep/tf_train_and_evaluate/trainer/__init__.py
wget https://raw.githubusercontent.com/amygdala/tensorflow-workshop/master/workshop_sections/wide_n_deep/tf_train_and_evaluate/trainer/model.py
wget https://raw.githubusercontent.com/amygdala/tensorflow-workshop/master/workshop_sections/wide_n_deep/tf_train_and_evaluate/trainer/task.py
cd ..
ls -l trainer

In [ ]:
output_dir = "census_%s" % (int(time.time()))
%env OUTPUT_DIR=$output_dir
! gcloud ml-engine local train --package-path trainer \
                           --module-name trainer.task \
                           -- \
                           --train-files $TRAIN_FILE \
                           --eval-files $EVAL_FILE \
                           --train-steps 1000 \
                           --job-dir $OUTPUT_DIR \
                           --eval-steps 100

Launch a distributed training job on Cloud ML Engine

Now, let's use Cloud ML Engine (CMLE) to do distributed training in the cloud. Here's where you'll use your GCP project and CMLE setup. The CMLE setup instructions included creation of a Google Cloud Storage (GCS) bucket, which we'll use below.

We'll set the training job to use the STANDARD_1 scale tier. This gives you one 'master' instance, plus four workers and three parameter servers.

The cool thing about this is that we don't need to change our code at all to use this distributed config. Our use of the Estimator class in conjunction with the CMLE scale spec allows the distributed training config to be transparent to us -- it just works. For example, we could alternately set the --scale-tier config to use GPUs, without making any changes to our code.
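
One way to picture why no code changes are needed: on each machine of a distributed job, the cluster layout is passed to TensorFlow through the TF_CONFIG environment variable, which RunConfig reads. The sketch below builds the kind of value a worker in a cluster of one master, four workers and three parameter servers might see. The host names and port are invented for illustration, and the helper make_tf_config is hypothetical; the real values are set for you by the service.

```python
import json


def make_tf_config(task_type, task_index, num_workers=4, num_ps=3):
    """Build a TF_CONFIG-style JSON string for one task in the cluster."""
    cluster = {
        'master': ['master-0.example.internal:2222'],
        'worker': ['worker-%d.example.internal:2222' % i
                   for i in range(num_workers)],
        'ps': ['ps-%d.example.internal:2222' % i for i in range(num_ps)],
    }
    if task_index >= len(cluster[task_type]):
        raise ValueError('no %s task with index %d' % (task_type, task_index))
    return json.dumps({'cluster': cluster,
                       'task': {'type': task_type, 'index': task_index}})


config = json.loads(make_tf_config('worker', 3))
print(config['task'])
print({job: len(hosts) for job, hosts in config['cluster'].items()})
```

Every machine runs the same training code; only the task entry differs, which is how each process learns its role.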

Notes: Each job requires a unique name, so rerun the cell that sets the env vars below each time you submit another job, if you want to run the following more than once.
Your CMLE training job can take a few minutes to spin up, but for larger training jobs the startup time is a small percentage of the overall computation.

If you're running this notebook on your local machine, ensure you've initialized gcloud as described in the setup instructions, and configured it to use your GCP project.

If you're running this notebook on colab, run the commands in the following cell to authorize your GCP account, after editing to use your own project ID.


In [ ]:
# Run these commands if you're running this notebook on colab.
# Be sure to first edit to use your own project ID.

from google.colab import auth
auth.authenticate_user()

# **Replace the following with your project ID!**
project_id = 'your-project-id'
!gcloud config set project {project_id}

Now, we'll launch a cloud training job. Edit the following to point to your GCS bucket directory.


In [ ]:
job_name = "census_job_%s" % (int(time.time()))

# Edit the following to point to your GCS bucket directory
gcs_job_dir = "gs://your-gcs-bucket/path/%s" % job_name
# For training on CMLE, we'll use datasets stored in Google Cloud Storage (GCS) instead of local files.
%env GCS_TRAIN_FILE=gs://cloudml-public/census/data/adult.data.csv
%env GCS_EVAL_FILE=gs://cloudml-public/census/data/adult.test.csv
%env SCALE_TIER=STANDARD_1
%env JOB_NAME=$job_name
%env GCS_JOB_DIR=$gcs_job_dir

In [ ]:
# submit your distributed training job to CMLE
!gcloud ml-engine jobs submit training $JOB_NAME --scale-tier $SCALE_TIER \
    --runtime-version 1.6 --job-dir $GCS_JOB_DIR \
    --module-name trainer.task --package-path trainer/ \
    --region us-central1 \
    -- --train-steps 10000 --train-files $GCS_TRAIN_FILE --eval-files $GCS_EVAL_FILE --eval-steps 100

You can monitor the status of your job via the stream-logs command indicated above.
(You can call gcloud ml-engine jobs submit training with the --stream-logs flag to stream the output logs right away).
You can also monitor the status of your job in the Cloud Console: console.cloud.google.com/mlengine/jobs
In the logs, you'll see output from 4 worker replicas, numbered 0-3.

Scalably serve your trained model with CMLE Online Prediction

Once your job is finished, you'll find the exported model under $GCS_JOB_DIR, in addition to other data such as checkpoints. You can now deploy the exported model to Cloud ML Engine and scalably serve it for prediction, using the CMLE Online Prediction service.


In [ ]:
# Run this when the training job is finished.  Look for the directory with the 'saved_model.pb' file.
!gsutil ls -R $GCS_JOB_DIR

Based on the output above, edit the following to point to the GCS directory containing saved_model.pb.


In [ ]:
# This is just an example.
# Edit this path to point to the GCS directory that contains your saved_model.pb binary
%env MODEL_BINARY=$gcs_job_dir/export/census/<timestamp>/

Create a 'census' model on CMLE (you'll get an error if that model name already exists).


In [ ]:
!gcloud ml-engine models create census --regions us-central1
!gcloud ml-engine models list
!gcloud ml-engine versions list --model census

Next, deploy your trained model binary to CMLE as v1 of the census model. This will let you use it for prediction. (You'll get an error if version 'v1' already exists. In that case, you can use a different version name).


In [ ]:
!gcloud ml-engine versions create v1 --model census --origin $MODEL_BINARY --runtime-version 1.6

You can look at the versions of a model in the Cloud Console, as well as set the default version: console.cloud.google.com/mlengine/models

Now you can use your deployed model for prediction. We've included a file, test.json, that encodes the input instance.


In [ ]:
# Use your deployed model for prediction
!cat test.json
!gcloud ml-engine predict --model census --version v1 --json-instances test.json

Extras: Train on CMLE using a custom GPU cluster

Above, we used the STANDARD_1 scale tier to train our model. If you had wanted to train on 1 GPU, you could have used BASIC_GPU instead.

You can train on a larger GPU cluster just as easily; with gcloud, it's just a matter of defining a .yaml config file that describes your cluster, and passing that config when you submit your training job. Note that using GPUs is more expensive, so it will cost more to run this part of the example.

To see how we'd do this, let's first take a look at the config file.

If you're running this notebook on colab, you'll first need to download this config file:


In [ ]:
%%bash
wget https://raw.githubusercontent.com/amygdala/tensorflow-workshop/master/workshop_sections/wide_n_deep/tf_train_and_evaluate/config_custom_gpus.yaml

In [ ]:
!cat config_custom_gpus.yaml

We're using NVIDIA Tesla P100 GPUs for our master and worker nodes (which is quite overkill for this small example!) We're using standard CPU nodes for the parameter servers. You can find more info on the node types here.

We'll just run our job as before, except now we specify CUSTOM scale tier, and point to our config file.
As before, you'll need to edit the GCS bucket path in the next cell.


In [ ]:
job_name = "census_job_%s" % (int(time.time()))
# Edit the following to point to your GCS bucket directory
gcs_job_dir = "gs://your-gcs-bucket/path/%s" % job_name
%env GCS_TRAIN_FILE=gs://cloudml-public/census/data/adult.data.csv
%env GCS_EVAL_FILE=gs://cloudml-public/census/data/adult.test.csv
%env SCALE_TIER=CUSTOM
%env JOB_NAME=$job_name
%env GCS_JOB_DIR=$gcs_job_dir

In [ ]:
!gcloud ml-engine jobs submit training $JOB_NAME --scale-tier $SCALE_TIER \
    --runtime-version 1.6 --job-dir $GCS_JOB_DIR \
    --module-name trainer.task --package-path trainer/ \
    --region us-central1 --config config_custom_gpus.yaml \
    -- --train-steps 15000 --train-files $GCS_TRAIN_FILE --eval-files $GCS_EVAL_FILE --eval-steps 100

Extras: Use Hyperparameter Tuning

CMLE makes it easy to do Hyperparameter tuning. See the documentation for more info.

For this run, we'll go back to using the STANDARD_1 tier. Note that because HP tuning does multiple runs — in this case, it will be 6 — this will be more expensive than the previous single runs. As before, you'll need to edit the GCS bucket path in the next cell to point to your bucket.


In [ ]:
job_name = "census_job_%s" % (int(time.time()))
# Edit the following to point to your GCS bucket directory
gcs_job_dir = "gs://your-gcs-bucket/path/%s" % job_name
%env GCS_TRAIN_FILE=gs://cloudml-public/census/data/adult.data.csv
%env GCS_EVAL_FILE=gs://cloudml-public/census/data/adult.test.csv
%env SCALE_TIER=STANDARD_1
%env JOB_NAME=$job_name
%env GCS_JOB_DIR=$gcs_job_dir

If you're running this notebook on colab, you'll first need to download this config file:


In [ ]:
%%bash
wget https://raw.githubusercontent.com/amygdala/tensorflow-workshop/master/workshop_sections/wide_n_deep/tf_train_and_evaluate/hptuning_config.yaml

In [ ]:
# We'll use the `hptuning_config.yaml` file for this run.
!cat hptuning_config.yaml

In [ ]:
!gcloud ml-engine jobs submit training $JOB_NAME --scale-tier $SCALE_TIER \
    --runtime-version 1.6 --job-dir $GCS_JOB_DIR \
    --module-name trainer.task --package-path trainer/ \
    --region us-central1 --config hptuning_config.yaml \
    -- --train-steps 15000 --train-files $GCS_TRAIN_FILE --eval-files $GCS_EVAL_FILE --eval-steps 100

You can easily look at the results in the Cloud Console: click on a job to see its details, including the HP trial outcomes. You can also see information about each trial reflected in the job logs. The checkpoints and export for each trial are saved to separate subdirectories, organized by trial number, under your job dir.


Copyright 2018 Google Inc. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

