TensorFlow versions: 1.2.*, 1.3
Estimators provide a high-level interface to TensorFlow. They are heavily inspired by the scikit-learn API.
Wrapping your model up in an Estimator gives you off-the-shelf between-graph replication for your training jobs through Experiment (which encapsulates the process of training a model) and learn_runner (which actually handles the replication). This makes it particularly easy to take advantage of the scaling functionality on Cloud Machine Learning Engine, for example.
Using the Estimator interface makes it easy to save and restore your trained models, which has generally been something of a stumbling block when it comes to using TensorFlow.
TensorFlow ships with a number of canned Estimators -- LinearClassifier, LinearRegressor, DNNClassifier, DNNRegressor, DNNLinearCombinedClassifier, DNNLinearCombinedRegressor -- and it is recommended that new model definitions also be wrapped up inside Estimators, so more and more people are going to be using them. Defining your model within an Estimator allows it to be dropped in with minimal fuss in place of existing Estimators within a given workflow. This allows not just you, but all potential users of your Estimator, to iterate painlessly.
If you are interested in (eventually) training your model on TPUs, wrapping it in an Estimator will go very far towards making sure that it is truly leveraging the power of the TPU. This is because TensorFlow uses the interface to determine optimal placement of operations on devices.
TensorFlow currently contains two different Estimator interfaces: the older tf.contrib.learn.Estimator and the newer tf.estimator.Estimator. tf.estimator.Estimator is the current interface, and is the one that we will be using in this notebook. However, we will also be using some utilities that make it easier to work with estimators (like Experiments), which have not yet been ported over to the tf.estimator module. This will require some rather hacky modifications to certain tf.contrib.learn concepts. As TensorFlow improves the situation with tf.estimator, we will reflect those changes here.
We will describe the interface and implement a lightweight Estimator to perform logistic regression.
In this initial example, we will be storing all our data in memory. However, this is not the typical use case for TensorFlow training. So we will, in a bonus section, extend our example to read data from files. This extension is intended to highlight good practices, and the intention is that you be able to paste the input function we produce into your own code and instantly benefit from its performance advantages over manually feeding data into your model.
We will set up an Experiment to manage model training and export.
In [ ]:
from __future__ import print_function
import numpy as np
import tensorflow as tf
The tf.estimator.Estimator class defines the high-level TensorFlow estimator interface.
If you want to construct your own estimator, you must define a function which constructs the appropriate TensorFlow model (graph + operations). The estimator is then constructed by initializing a tf.estimator.Estimator, passing your model function as the model_fn argument and, optionally:
model_dir -- Directory in which to save information about the model.
config -- RunConfig to be used when running the estimator, describing properties extrinsic to the model like the cluster specification, checkpointing strategy, etc.
params -- Dictionary of model hyperparameters.
The bulk of the work is in implementing your model_fn.
There are two different versions of RunConfig: tf.contrib.learn.RunConfig and tf.estimator.RunConfig. We will use tf.contrib.learn.RunConfig in this notebook because tf.estimator.RunConfig breaks compatibility with tf.contrib.learn.Experiment, which is a very useful utility that we will discuss below.
Your model_fn will be called with the following arguments:
features -- Input to the model. Either a single tensor or a dictionary of labelled (by the string key) tensors.
labels -- Labels corresponding to the provided features (in PREDICT mode, this is guaranteed to be None). This can be a single tensor or, in the case of a multi-headed model (multiple heads at the output layer), a dictionary of tensors. (In the multi-headed case, it is not clear what the keys have to be. Presumably they have to match the head variable names?)
mode -- The mode in which the model is to be run.
params -- This is where the hyperparameters passed to the Estimator upon instantiation are made available; they are passed through transparently.
config -- The RunConfig passed to the Estimator upon instantiation is made available to the model_fn here.
Your model_fn must return an EstimatorSpec. This is what allows the constructed model to play nicely with the TensorFlow Estimator interface.
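To make the contract concrete, here is a minimal sketch of a model_fn that satisfies it. The name, the 'inputs' key, and the trivial squared-error logic are purely illustrative; the real model_fn for this notebook is defined further below.
In [ ]:
def skeleton_model_fn(features, labels, mode, params=None, config=None):
    # Purely illustrative "model": pass the inputs straight through as predictions.
    predictions = features['inputs']
    loss = None
    train_op = None
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        # A stand-in loss; a real model_fn would compute something meaningful here.
        loss = tf.reduce_mean(tf.square(predictions - labels))
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
            loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=loss, train_op=train_op)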
An Estimator has three methods that qualify as a run: evaluate, predict, and train. Each of these methods accepts as its first argument an input function, input_fn.
The input_fn should be callable with no arguments and should return an ordered pair of features and labels, with each being either a tensor or a string-keyed dictionary of tensors.
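As a minimal illustration of this contract (the function name and the 'inputs' key below are our own choices, not requirements of the interface):
In [ ]:
def toy_input_fn():
    # features: a string-keyed dictionary of tensors; labels: a single tensor.
    features = {'inputs': tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float64)}
    labels = tf.constant([[0.0], [1.0], [1.0]], dtype=tf.float64)
    return features, labels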
When you call the evaluate, predict, or train method of an estimator:
The input_fn is called to produce features and labels.
These are passed to the model_fn the Estimator was instantiated with.
The resulting model has its variables populated from a checkpoint in the model directory that was specified to the Estimator upon instantiation. If no specific checkpoint is requested, the latest one is used. If no checkpoint exists, the variables are freshly initialized.
Depending on the mode, the appropriate operations are run in a TensorFlow MonitoredSession, and the results of these operations are yielded back to the caller.
Our classifier will perform a logistic regression on a single feature. Let us say that there are two latent classes, $0$ and $1$, corresponding to this feature.
Within class $0$, the values for the feature are distributed as the square of a normally distributed random variable with mean $m$ and variance $\sigma^2$.
Within class $1$, the values for the feature are exponentially distributed with rate $\lambda$.
In [ ]:
def sample_from_normal_distribution(n, mu, sigma):
    # Draw n samples from N(mu, sigma^2) and square them, per the class 0 description above.
    return np.square(np.random.normal(mu, sigma, n))
In [ ]:
MU = 10
SIGMA = 0.25
In [ ]:
example_0 = sample_from_normal_distribution(20, MU, SIGMA)
example_0
In [ ]:
def sample_from_exponential_distribution(n, rate):
    # np.random.exponential is parameterized by the scale, which is the inverse of the rate.
    return np.random.exponential(1.0 / rate, n)
In [ ]:
RATE = 1
In [ ]:
example_1 = sample_from_exponential_distribution(20, RATE)
example_1
We can use these samplers to make input functions for our estimator.
It is generally good practice to wrap up our inputs into a TensorFlow queue to produce the appropriate tensors in our graph, so let us do that. The nice thing about this is that we can use the queue itself to shuffle up our data rather than having to do it ourselves.
In the generate_input_fn below, we produce our input data as a numpy ndarray. We then use tf.train.input_producer as the entry point into the TensorFlow framework.
The BATCH_SIZE parameter determines how many data points we feed to our estimator at a time.
In [ ]:
BATCH_SIZE=1000
In [ ]:
def generate_input_fn(distributions, num_samples, num_epochs=None, shuffle=True, batch_size=BATCH_SIZE):
    assert len(distributions) == len(num_samples)
    components = len(distributions)
    # Draw the requested number of samples from each distribution and label each sample by its
    # component index.
    samples = np.expand_dims(np.concatenate([distributions[i](num_samples[i]) for i in range(components)]), 1)
    labels = np.concatenate([np.full((num_samples[i], 1), i, dtype=np.float64) for i in range(components)])
    # stack has shape (total_samples, 2, 1): samples and labels side by side.
    stack = np.stack([samples, labels], axis=1)
    def input_fn():
        q = tf.train.input_producer(stack, shuffle=shuffle, num_epochs=num_epochs)
        top = q.dequeue_up_to(batch_size)
        # Split up the features and the labels (in that order).
        raw_features, labels = tuple(tf.unstack(top, axis=1))
        # Wrap the features inside a dictionary.
        return {'inputs': raw_features}, labels
    return input_fn
Let us turn our attention to the estimator definition.
As we discussed above, everything hinges on our model function definition.
In [ ]:
MODES = tf.estimator.ModeKeys
def model_fn(features, labels, mode, params=None, config=None):
    print(features, labels)
    # Default learning rate is 0.1
    if params is None:
        params = {'learning_rate': 0.1}
    inputs = features.get('inputs')
    if inputs is None:
        raise ValueError('Input "features" did not define an "inputs" key')
    # We logistically estimate the probability of each sample belonging to the class labelled 1.
    weight = tf.Variable(tf.zeros([1, 1], dtype=tf.float64), name='weight')
    bias = tf.Variable(tf.zeros([1, 1], dtype=tf.float64), name='bias')
    logit = tf.add(tf.matmul(inputs, weight), bias, name='logit')
    logistic = tf.sigmoid(logit, name='logistic')
    loss = None
    train_op = None
    if mode in (MODES.TRAIN, MODES.EVAL):
        loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=labels, name='loss')
        )
        tf.summary.scalar('loss', loss)
    if mode == MODES.TRAIN:
        learning_rate = params.get('learning_rate')
        global_step = tf.train.get_global_step()
        train_op = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(
            loss, global_step=global_step)
    prediction_output = tf.estimator.export.PredictOutput({'class_1_probability': logistic})
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=logistic,
        loss=loss,
        train_op=train_op,
        export_outputs={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_output}
    )
With the model_fn defined, we can now instantiate an estimator.
We will store details about the model in MODEL_DIR. You can change the location by updating the variable below. A nice feature is that MODEL_DIR can even be a Google Cloud Storage bucket.
We will use the LEARNING_RATE variable to specify the learning rate of the optimizer used by our estimator. Feel free to change it.
In this tutorial, you will have to specify a few different directories for use with various estimators and (later) to store some data. You should specify fresh directories. To make sure that everything is set up as intended, the function below will accept a directory path from you and:
Ensure that a fresh directory is created along that path.
Raise errors if there is already something at that path.
Act as a marker for the cells in which you should add directories. Simply do a CTRL+F
for "setup_directory" and (excluding this cell and the one below it) you will find the cells that you should modify.
In [ ]:
import os
def setup_directory(directory_path):
os.makedirs(directory_path)
return directory_path
In [ ]:
MODEL_DIR = setup_directory('DIRECTORY PATH GOES HERE - IT SHOULD NOT POINT TO AN EXISTING DIRECTORY')
In [ ]:
LEARNING_RATE = 0.05
In [ ]:
run_config = tf.estimator.RunConfig().replace(save_summary_steps=10)
regressor = tf.estimator.Estimator(model_fn, model_dir=MODEL_DIR,
                                   params={'learning_rate': LEARNING_RATE},
                                   config=run_config)
Before we move on to training, evaluation, and prediction, it is worth pausing to discuss how the estimator allows for model reuse (for example, making use of previous training). Generally, in TensorFlow, there are two different means of storing and restoring the state of a TensorFlow graph:
Checkpoints -- These are written and read by a tf.train.Saver and store the values of the graph's variables. This is how training progress is saved and resumed.
SavedModelBuilder -- This is particularly useful when you want to create a prediction service using your trained model.
The beauty of using the Estimator is that it handles checkpointing by itself and completely under the hood, so our only relationship to it here was specifying the MODEL_DIR into which we wanted our checkpoints (and metagraph information) to be stored. Checkpoints are automatically stored and loaded when you call any of the Estimator methods.
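If you ever want to see which checkpoint the Estimator would restore from, you can ask TensorFlow for the latest checkpoint in the model directory (this simply prints None until the estimator has written its first checkpoint):
In [ ]:
# Returns the path of the most recent checkpoint in MODEL_DIR, or None if there is none yet.
print(tf.train.latest_checkpoint(MODEL_DIR))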
Moreover, note that the EstimatorSpec that we return from our model_fn contains an export_outputs dictionary. This dictionary is used to provide information to the SavedModelBuilder when we want to export our graph using the Estimator's export_savedmodel method. We will see this in action after we cover training, evaluation, and prediction.
In [ ]:
distributions = [lambda n: sample_from_normal_distribution(n, MU, SIGMA),
                 lambda n: sample_from_exponential_distribution(n, RATE)]
To begin with, let us train our estimator on a data set of 100,000 samples from each distribution:
In [ ]:
regressor.train(generate_input_fn(distributions, [100000, 100000]), steps=200)
In [ ]:
regressor.evaluate(generate_input_fn(distributions, [1000, 1000]), steps=20)
In [ ]:
list(regressor.predict(generate_input_fn(distributions, [1, 1], num_epochs=1, shuffle=False)))
The Estimator exposes an export_savedmodel method, which is a very convenient way to package a trained TensorFlow model, especially for serving purposes (either on your own or using the Cloud ML Engine Prediction service).
When you use this export_savedmodel method, you will specify the following things:
export_dir_base -- The base directory into which all the exports should go. The exports will be stored under a subdirectory of export_dir_base labelled by the timestamp at which the export took place.
serving_input_receiver_fn -- A function which creates a ServingInputReceiver object. This object specifies the inputs expected by the prediction server and how the features to be passed to the Estimator should be extracted from them.
assets_extra -- Use this only if you need to provide extra assets with your SavedModel. We won't demonstrate this here. Set to None by default.
as_text -- If you set this to True, the SavedModel will be exported in text format rather than as a .proto file. Set to False by default.
checkpoint_path -- Here, you can optionally specify exactly which checkpoint you want to have exported from your MODEL_DIR. If you leave this with its default value of None, the most recent checkpoint will be used.
We are pretty far down the rabbit hole now, but let's just focus on the interfaces before us. The serving_input_receiver_fn is supposed to be a function with no arguments which produces a ServingInputReceiver. A ServingInputReceiver is instantiated with two arguments -- features and receiver_tensors. The features represent the inputs to our Estimator when it is being served. The receiver_tensors represent the inputs to the server.
In our case, we will expect inputs to any server serving our model to be exactly the features that we feed the model.
In [ ]:
def serving_input_receiver_fn():
    # The same placeholder serves as both the receiver tensor (what the server accepts)
    # and the feature tensor (what the model consumes).
    feature_tensor = tf.placeholder(tf.float64, [None, 1])
    return tf.estimator.export.ServingInputReceiver({'inputs': feature_tensor}, {'inputs': feature_tensor})
In [ ]:
BASE_EXPORT_DIR = setup_directory('DIRECTORY PATH GOES HERE - IT SHOULD NOT POINT TO AN EXISTING DIRECTORY')
In [ ]:
regressor.export_savedmodel(BASE_EXPORT_DIR, serving_input_receiver_fn)
(A brief aside on serving TensorFlow models.)
Your best option if you want to serve a TensorFlow model that you have trained is to use TensorFlow Serving. This is basically an application which serves your TensorFlow models through a gRPC API. It makes use of the SavedModels exported by the Estimator's export_savedmodel method to do so.
The ML Engine on Google Cloud Platform offers a prediction service which manages TensorFlow Serving for you. All you have to do is provide it with versioned exports of your TensorFlow model, and it handles the serving portion for you. If you are interested in this, check out the ML Engine Prediction Overview.
For an example of how to use TensorFlow Serving in the real world, have a look at this blog post by Wai Chee Yau, who actually added multiple version serving functionality to the module!
(Note: TensorFlow Serving does not integrate directly with Google Cloud Storage in the same way that TensorFlow does. You will have to download your exported SavedModels to the server on which you intend to run TensorFlow Serving.)
tf.contrib.learn.Experiment is a very useful aid in managing training and evaluation of a model. The class was originally designed for use with tf.contrib.learn.Estimator, but has been extended to also work with tf.estimator.Estimator. There are some nuances to making it work with tf.estimator.Estimator, however. We shall cover them in this section.
Unfortunately, the integration with tf.estimator.Estimator is buggy. The source of the bug is essentially that the Experiment still wants the Estimator to be instantiated with a tf.contrib.learn.RunConfig, whereas a tf.estimator.Estimator is typically instantiated with a tf.estimator.RunConfig.
Thankfully, the attributes on a tf.estimator.RunConfig instance are a subset of those on a tf.contrib.learn.RunConfig instance, so we can simply use a tf.contrib.learn.RunConfig in place of what we had before. In the case of our regressor, we were using a tf.estimator.RunConfig. To make a regressor which works with Experiment, we simply pass a tf.contrib.learn.RunConfig as the config argument, as below:
In [ ]:
NEW_MODEL_DIR = setup_directory('DIRECTORY PATH GOES HERE - IT SHOULD NOT POINT TO AN EXISTING DIRECTORY')
In [ ]:
experiment_compatible_regressor = tf.estimator.Estimator(model_fn, model_dir=NEW_MODEL_DIR,
                                                         params={'learning_rate': LEARNING_RATE},
                                                         config=tf.contrib.learn.RunConfig())
Another notion to consider when dealing with Experiments is that of an ExportStrategy. Export strategies dictate to the Experiment how to export various versions of the estimator for serving, potentially even with different signatures, as it performs (potentially continuous) training and evaluation.
When working directly with an estimator, we would manually make calls to the export_savedmodel method with specific checkpoints and with specific serving_input_receiver_fns in order to achieve a similar effect.
Now, there is a pretty useful utility for generating export strategies when you are using tf.contrib.learn.Estimator -- tf.contrib.learn.make_export_strategy. However, this utility will not work for us because it requires an old-style serving_input_fn, which returns an InputFnOps object rather than a ServingInputReceiver object.
Never mind, though, because it is quite easy to define our own export strategy directly. The first step is to extract our export logic into an export_fn as follows:
In [ ]:
def export_fn(estimator, export_path, checkpoint_path=None):
    return estimator.export_savedmodel(export_path,
                                       serving_input_receiver_fn,
                                       checkpoint_path=checkpoint_path)
For more information on export_fns, you can refer directly to the ExportStrategy docs.
The TL;DR version is that we can create an export strategy for use with our experiment as follows:
In [ ]:
export_strategy = tf.contrib.learn.ExportStrategy('default', export_fn)
We can define our experiment:
In [ ]:
experiment = tf.contrib.learn.Experiment(experiment_compatible_regressor,
                                         train_input_fn=generate_input_fn(distributions, [10000, 10000]),
                                         eval_input_fn=generate_input_fn(distributions, [10, 10]),
                                         train_steps=100,
                                         export_strategies=export_strategy)
By specifying train_steps=100, we are saying that our experiment will only consist of training up to a global_step of 100. This is significantly different from saying that we will train for 100 more global_steps. If we had already surpassed global_step 100, running train on this experiment would achieve nothing!
If we wanted to train forever, we would pass train_steps=None, which is actually the default.
Now, we can train our estimator, evaluate it, and export the saved model for serving all with one command:
In [ ]:
experiment.train_and_evaluate()
Note: The saved models generated by train_and_evaluate are stored in the exports subdirectory of the regressor's model_dir. If you want to explicitly provide some other location, you can make the appropriate change to the export_fn above.
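If you want to see exactly what the Experiment wrote to disk, one quick way is to walk the model directory (os was imported earlier; the exact layout of subdirectories may vary with your TensorFlow version):
In [ ]:
# Walk the model directory to see the checkpoints and exported SavedModels it now contains.
for root, dirs, files in os.walk(NEW_MODEL_DIR):
    print(root, files)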
As a sanity check, let us look at the predictions of this new regressor:
In [ ]:
list(experiment_compatible_regressor.predict(
generate_input_fn(distributions, [1, 1], num_epochs=1, shuffle=False)
))
In the example above, we were generating all of our data (training, evaluation, and prediction) in memory. Although pedagogically convenient, this is practically useless to us in our everyday lives.
In this section, we will rewire the whole example to read its training and evaluation data from disk. The beauty of it is that we will only need to make some very minor changes to our input_fn generator.
First, we should produce our input files. We will write labelled data to 2-column CSV files in which the first column contains the label and the second column contains the number that was sampled from the distribution corresponding to the label.
We will create NUM_FILES files in DATA_DIR, each one containing SAMPLES_PER_FILE points of training data.
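For concreteness, a few rows of one of these files might look like the following (these numbers are purely illustrative, not actual output):
1,0.8342
0,103.5219
0,96.0871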
In [ ]:
DATA_DIR = setup_directory('DIRECTORY PATH GOES HERE - IT SHOULD NOT POINT TO AN EXISTING DIRECTORY')
In [ ]:
NUM_FILES = 100
In [ ]:
SAMPLES_PER_FILE = 1000
In [ ]:
def create_labelled_data_file(filename, distributions, labels):
    labels_list = list(labels)
    # Draw one sample from the distribution corresponding to each label.
    samples_list = [item[0] for item in map(lambda i: distributions[i](1), labels)]
    rows = [','.join([str(entry) for entry in row]) for row in zip(labels_list, samples_list)]
    content = '\n'.join(rows)
    with open(filename, 'w') as f:
        f.write(content)
In [ ]:
filenames = tuple(['{}/tf-estimator-data-{}.csv'.format(DATA_DIR, i) for i in range(NUM_FILES)])
In [ ]:
for i in range(NUM_FILES):
    # Generate SAMPLES_PER_FILE labels at random with each label having probability 0.5
    labels = (np.random.uniform(size=SAMPLES_PER_FILE) > 0.5).astype(int)
    create_labelled_data_file(filenames[i], distributions, labels)
Now that the files have been created, we can define an input_fn generator which produces an Estimator input function that reads data from those files.
This construction basically follows the Reading Data section of the tensorflow.org Programmer's Guide. However, the comments in the code block below may be of some value to you if you had trouble following the steps in the Reading Data guide.
In [ ]:
def generate_file_input_fn(filenames, num_epochs=None, shuffle=True):
    def input_fn():
        filename_queue = tf.train.string_input_producer(filenames,
                                                        num_epochs,
                                                        shuffle,
                                                        capacity=len(filenames))
        file_reader = tf.TextLineReader()
        _, csv = file_reader.read(filename_queue)
        # The record_defaults specify not only the defaults, but they also tell the decoder what
        # schema to expect in the rows.
        #
        # The valid dtypes for record_defaults are tf.int32, tf.int64, tf.string, tf.float32.
        # Since tf.float64 is not allowed and since our model_fn expects the inputs and the labels
        # to have dtype=tf.float64, we will have to recast later.
        record_defaults = [tf.constant([0], dtype=tf.float32),
                           tf.constant([0], dtype=tf.float32)]
        labels, raw_features = tf.decode_csv(csv, record_defaults=record_defaults)
        # Not only do we have to recast our features and labels to tf.float64, but we also
        # have to explicitly specify the shape of each column. If we did not do this, TensorFlow
        # would error out at graph construction time because it would not be able to validate
        # the shapes of the tensors used in its operations.
        #
        # When designing your own estimator for release, it would be advisable to put this kind
        # of code in your model_fn, making it easier for your users to provide custom input_fns.
        processed_labels = tf.reshape(tf.cast(labels, tf.float64), [-1, 1])
        processed_features = tf.reshape(tf.cast(raw_features, tf.float64), [-1, 1])
        return {'inputs': processed_features}, processed_labels
    return input_fn
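If you would like to sanity-check what this input function produces before handing it to an Experiment, the sketch below starts the queue runners that string_input_producer and the reader depend on, and pulls a single record. This is just a debugging aid, not part of the training workflow:
In [ ]:
with tf.Graph().as_default():
    features, labels = generate_file_input_fn(filenames, num_epochs=1)()
    with tf.Session() as sess:
        # Setting num_epochs creates local variables, so both initializers are needed.
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Read and print one record's features and label.
        print(sess.run([features['inputs'], labels]))
        coord.request_stop()
        coord.join(threads)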
Let us now use this to continue training our experiment_compatible_regressor. All we have to do is make a new Experiment!
(This really highlights the utility of tf.contrib.learn.Experiment.)
In [ ]:
train_from_file_experiment = tf.contrib.learn.Experiment(experiment_compatible_regressor,
                                                         train_input_fn=generate_file_input_fn(filenames),
                                                         eval_input_fn=generate_input_fn(distributions, [10, 10]),
                                                         train_steps=10000,
                                                         export_strategies=export_strategy)
In [ ]:
train_from_file_experiment.train_and_evaluate()
In [ ]:
list(experiment_compatible_regressor.predict(
generate_input_fn(distributions, [1, 1], num_epochs=1, shuffle=False)
))
And we're done!
In [ ]:
print('REMOVE THESE DIRECTORIES:\n1. {}\n2. {}\n3. {}\n4. {}'.format(MODEL_DIR, NEW_MODEL_DIR, BASE_EXPORT_DIR, DATA_DIR))
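If you would rather clean up programmatically, something like the following will do it. Be careful: this permanently deletes the directories created in this notebook, so only run it when you are completely done.
In [ ]:
import shutil
# WARNING: this irreversibly removes the model, export, and data directories created above.
for directory in [MODEL_DIR, NEW_MODEL_DIR, BASE_EXPORT_DIR, DATA_DIR]:
    shutil.rmtree(directory)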
Copyright 2017 Google, Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.