This tutorial is an end-to-end walkthrough of training a TensorFlow Ranking (TF-Ranking) neural network model which incorporates sparse textual features.
A Python script version of this code is available here. The script version supports flags for hyperparameters and advanced use cases such as Document Interaction Networks.
TF-Ranking is a library for solving large-scale ranking problems using deep learning. It can handle heterogeneous dense and sparse features, and scales up to millions of data points. For more details, please read the technical paper published on arXiv.
Learning to Rank (LTR) deals with learning to optimally order a list of examples, given some context. For instance, in search applications, examples are documents and context is the query.
These models are usually trained using user relevance feedback, which can be explicit (human ratings) or implicit (clicks).
This tutorial demonstrates how to build ranking estimators over sparse features, such as textual data. Textual data is prevalent in many ranking settings and plays a significant role in a user's relevance judgments.
In three different LTR scenarios, the following textual features provide useful signals for ranking:
Hence it is important for LTR models to effectively incorporate textual features.
For the purpose of this tutorial, we consider the ranking problem on ANTIQUE, a question-answering dataset. Given a query and a list of answers, the objective is to maximize a rank-related metric (for example, NDCG).
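To make the objective concrete, here is a small plain-Python sketch of NDCG@k for a single list. It assumes the common exponential gain (2^rel - 1) and the logarithmic discount 1/log2(position + 1); TF-Ranking's own metric also handles batches and padded entries and may differ in details such as the gain and discount functions. The function names are illustrative, not part of the library.

```python
import math
from typing import Sequence


def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """DCG of the first k items, given their labels in ranked order."""
    return sum(
        # Exponential gain, discounted by log2(position + 1) with 1-based positions.
        (2.0 ** rel - 1.0) / math.log2(position + 1)
        for position, rel in enumerate(relevances[:k], start=1)
    )


def ndcg_at_k(relevances: Sequence[float], k: int) -> float:
    """DCG normalized by the DCG of the ideal (label-sorted) ordering."""
    ideal = dcg_at_k(sorted(relevances, reverse=True), k)
    return dcg_at_k(relevances, k) / ideal if ideal > 0 else 0.0


# Labels of the answers, listed in the order the model ranked them.
print(round(ndcg_at_k([5, 3, 1], k=3), 3))  # ideal order -> 1.0
print(round(ndcg_at_k([1, 3, 5], k=3), 3))  # reversed order -> 0.582
```

Because the ideal ordering is the denominator, NDCG lies between 0 and 1 and reaches 1 only when the answers are sorted by decreasing relevance.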
ANTIQUE is a publicly available dataset for open-domain non-factoid question answering, collected from Yahoo! Answers.
Each question has a list of answers, whose relevance is graded on a scale of 1 to 5.
The list size can vary depending on the query, so we use a fixed "list size" of 50, where the list is either truncated or padded with dummy values.
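The idea of a fixed list size can be illustrated in plain Python. In the pipeline below, `tfr.data.build_ranking_dataset` does this work through its `list_size` argument, and padded examples receive the label feature's default value, which the configuration below sets to -1. The helper here is hypothetical and simply keeps the first `list_size` entries.

```python
from typing import List

# Assumption: -1 marks padded entries, matching the _PADDING_LABEL used later.
PADDING_LABEL = -1


def pad_or_truncate(labels: List[int], list_size: int,
                    padding_label: int = PADDING_LABEL) -> List[int]:
    """Returns exactly `list_size` labels: truncated, or padded at the end."""
    if len(labels) >= list_size:
        return labels[:list_size]
    return labels + [padding_label] * (list_size - len(labels))


print(pad_or_truncate([4, 2, 5], list_size=5))               # [4, 2, 5, -1, -1]
print(len(pad_or_truncate(list(range(60)), list_size=50)))   # 50
```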
This dataset is well suited to learning to rank. It is split into 2206 queries for training and 200 queries for testing. For more details, please read the technical paper on arXiv.
Download the training data, test data, and vocabulary file.
In [0]:
!wget -O "/tmp/vocab.txt" "http://ciir.cs.umass.edu/downloads/Antique/tf-ranking/vocab.txt"
!wget -O "/tmp/train.tfrecords" "http://ciir.cs.umass.edu/downloads/Antique/tf-ranking/ELWC/train.tfrecords"
!wget -O "/tmp/test.tfrecords" "http://ciir.cs.umass.edu/downloads/Antique/tf-ranking//ELWC/test.tfrecords"
Next, we discuss data formats in more detail, and show how to generate and store dummy ranking data.
Protocol buffers are extensible structures suitable for storing data in a serialized format, either locally or in a distributed manner, which makes them a good fit for representing ranking data.
Ranking data usually consists of features corresponding to each of the examples being sorted. In addition, features related to the query, user, or session are also useful for ranking. We refer to these as context features, since they are independent of the individual examples.
We use the popular tf.Example proto to represent the features of the context and of each example. We use the ExampleListWithContext (ELWC) protocol buffer to store the context as a tf.Example proto and the list of examples to be ranked as a list of tf.Example protos.
The ExampleListWithContext protocol buffer is defined here.
Let us create some dummy data in ELWC format, which we will use to show what the proto looks like.
Download and install the TensorFlow 2 package.
In [0]:
print('Installing TensorFlow 2.1.0. This will take a minute, ignore the warnings.')
!pip install -q tensorflow==2.1.0
import tensorflow as tf
# This is needed for tensorboard compatibility.
!pip uninstall -y grpcio
!pip install -q "grpcio>=1.24.3"
In [0]:
from google.protobuf import text_format
CONTEXT = text_format.Parse(
"""
features {
feature {
key: "query_tokens"
value { bytes_list { value: ["this", "is", "a", "relevant", "question"] } }
}
}""", tf.train.Example())
In [0]:
EXAMPLES = [
text_format.Parse(
"""
features {
feature {
key: "document_tokens"
value { bytes_list { value: ["this", "is", "a", "relevant", "answer"] } }
}
feature {
key: "relevance"
value { int64_list { value: 5 } }
}
}""", tf.train.Example()),
text_format.Parse(
"""
features {
feature {
key: "document_tokens"
value { bytes_list { value: ["irrelevant", "data"] } }
}
feature {
key: "relevance"
value { int64_list { value: 1 } }
}
}""", tf.train.Example()),
]
In [0]:
try:
    from tensorflow_serving.apis import input_pb2
except ImportError:
    !pip install -q tensorflow-serving-api
    from tensorflow_serving.apis import input_pb2
ELWC = input_pb2.ExampleListWithContext()
ELWC.context.CopyFrom(CONTEXT)
for example in EXAMPLES:
    example_features = ELWC.examples.add()
    example_features.CopyFrom(example)
In [0]:
print(ELWC)
In [0]:
import six
import os
import numpy as np
try:
    import tensorflow_ranking as tfr
except ImportError:
    !pip install -q tensorflow_ranking
    import tensorflow_ranking as tfr
Here we define the train and test paths, along with model hyperparameters.
In [0]:
# Store the paths to files containing training and test instances.
_TRAIN_DATA_PATH = "/tmp/train.tfrecords"
_TEST_DATA_PATH = "/tmp/test.tfrecords"
# Store the vocabulary path for query and document tokens.
_VOCAB_PATH = "/tmp/vocab.txt"
# The maximum number of documents per query in the dataset.
# Document lists are padded or truncated to this size.
_LIST_SIZE = 50
# The document relevance label.
_LABEL_FEATURE = "relevance"
# Padding labels are set negative so that the corresponding examples can be
# ignored in loss and metrics.
_PADDING_LABEL = -1
# Learning rate for optimizer.
_LEARNING_RATE = 0.05
# Parameters to the scoring function.
_BATCH_SIZE = 32
_HIDDEN_LAYER_DIMS = ["64", "32", "16"]
_DROPOUT_RATE = 0.8
_GROUP_SIZE = 1 # Pointwise scoring.
# Location of model directory and number of training steps.
_MODEL_DIR = "/tmp/ranking_model_dir"
_NUM_TRAIN_STEPS = 15 * 1000
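As a quick sanity check on these settings, the arithmetic below estimates how many passes over the training data 15,000 steps amount to. It assumes each training step consumes one batch of `_BATCH_SIZE` query lists and uses the 2206 training queries mentioned earlier.

```python
import math

# Figures from this tutorial; assumes one batch of query lists per step.
num_train_queries = 2206
batch_size = 32
num_train_steps = 15 * 1000

steps_per_epoch = math.ceil(num_train_queries / batch_size)
approx_epochs = num_train_steps * batch_size / num_train_queries

print(steps_per_epoch)          # 69
print(round(approx_epochs, 1))  # 217.6
```

So the model sees each training query roughly 200 times, which is one reason the regularization settings (dropout, batch normalization) matter here.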
The overall components of a Ranking Estimator are shown below.
The key components of the library are:
These are described in more detail in the following sections.
Feature columns are TensorFlow abstractions used to capture rich information about each feature. They allow easy transformations of a diverse range of raw features and provide the interface to Estimators.
Consistent with our input formats for ranking, such as the ELWC format, we create feature columns for context features and for example features.
In [0]:
_EMBEDDING_DIMENSION = 20


def context_feature_columns():
    """Returns context feature names to column definitions."""
    sparse_column = tf.feature_column.categorical_column_with_vocabulary_file(
        key="query_tokens",
        vocabulary_file=_VOCAB_PATH)
    query_embedding_column = tf.feature_column.embedding_column(
        sparse_column, _EMBEDDING_DIMENSION)
    return {"query_tokens": query_embedding_column}


def example_feature_columns():
    """Returns the example feature columns."""
    sparse_column = tf.feature_column.categorical_column_with_vocabulary_file(
        key="document_tokens",
        vocabulary_file=_VOCAB_PATH)
    document_embedding_column = tf.feature_column.embedding_column(
        sparse_column, _EMBEDDING_DIMENSION)
    return {"document_tokens": document_embedding_column}
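Conceptually, each pair of columns above maps tokens to integer IDs through the vocabulary file and then combines the matching rows of a learned embedding table into one fixed-size vector (`_EMBEDDING_DIMENSION` wide). The plain-Python sketch below assumes the defaults as we understand them: a mean combiner, with out-of-vocabulary tokens ignored because no OOV buckets are configured. The helper names and the toy table are hypothetical.

```python
from typing import Dict, List, Sequence


def build_vocab(tokens: Sequence[str]) -> Dict[str, int]:
    """Maps each vocabulary entry to its line index, like a vocabulary file."""
    return {token: index for index, token in enumerate(tokens)}


def embed_tokens(tokens: Sequence[str], vocab: Dict[str, int],
                 table: List[List[float]]) -> List[float]:
    """Mean of the embedding rows of in-vocabulary tokens; zeros if none."""
    ids = [vocab[t] for t in tokens if t in vocab]  # OOV tokens are dropped.
    dim = len(table[0])
    if not ids:
        return [0.0] * dim
    return [sum(table[i][d] for i in ids) / len(ids) for d in range(dim)]


vocab = build_vocab(["this", "is", "relevant"])
# Toy 2-dimensional embedding table (learned during training in the real model).
table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(embed_tokens(["this", "relevant", "unseen"], vocab, table))  # [1.5, 1.0]
```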
The input reader reads in data from persistent storage to produce raw dense and sparse tensors of appropriate type for each feature. Example features are represented by 3-D tensors (where dimensions correspond to queries, examples and feature values). Context features are represented by 2-D tensors (where dimensions correspond to queries and feature values).
In [0]:
def input_fn(path, num_epochs=None):
    """Reads ELWC TFRecords into batched features and a [batch, list] label."""
    context_feature_spec = tf.feature_column.make_parse_example_spec(
        context_feature_columns().values())
    label_column = tf.feature_column.numeric_column(
        _LABEL_FEATURE, dtype=tf.int64, default_value=_PADDING_LABEL)
    example_feature_spec = tf.feature_column.make_parse_example_spec(
        list(example_feature_columns().values()) + [label_column])
    dataset = tfr.data.build_ranking_dataset(
        file_pattern=path,
        data_format=tfr.data.ELWC,
        batch_size=_BATCH_SIZE,
        list_size=_LIST_SIZE,
        context_feature_spec=context_feature_spec,
        example_feature_spec=example_feature_spec,
        reader=tf.data.TFRecordDataset,
        shuffle=False,
        num_epochs=num_epochs)
    features = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    # Labels arrive as [batch_size, list_size, 1]; drop the trailing axis.
    label = tf.squeeze(features.pop(_LABEL_FEATURE), axis=2)
    label = tf.cast(label, tf.float32)
    return features, label
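The label is parsed as an example feature, so it arrives in the 3-D example layout `[batch_size, list_size, 1]`; the `tf.squeeze(..., axis=2)` call in `input_fn` drops the last axis to leave one label per example. The nested-list sketch below uses hypothetical helpers as stand-ins for tensors to show the shapes involved.

```python
from typing import Any, List, Tuple


def shape(nested: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list (a stand-in for a dense tensor)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


def squeeze_last(nested: List[List[List[float]]]) -> List[List[float]]:
    """Drops a trailing size-1 axis: [batch, list, 1] -> [batch, list]."""
    return [[values[0] for values in example_list] for example_list in nested]


# Toy batch: 2 queries, list size 3, one label per example (-1 marks padding).
labels = [[[5], [1], [-1]], [[3], [-1], [-1]]]
print(shape(labels))                # (2, 3, 1)
print(shape(squeeze_last(labels)))  # (2, 3)
```

Context features such as `query_tokens` have no list axis, which is why they are 2-D rather than 3-D.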
The transform function takes the raw dense or sparse features from the input reader and applies suitable transformations to return dense representations for each feature. This is important before passing these features to a neural network, as neural network layers usually take dense features as inputs.
The transform function handles any custom feature transformations defined by the user. For handling sparse features, like text data, we provide an easy utility to create shared embeddings based on the feature columns.
In [0]:
def make_transform_fn():
    def _transform_fn(features, mode):
        """Defines transform_fn."""
        context_features, example_features = tfr.feature.encode_listwise_features(
            features=features,
            context_feature_columns=context_feature_columns(),
            example_feature_columns=example_feature_columns(),
            mode=mode,
            scope="transform_layer")
        return context_features, example_features
    return _transform_fn
Next, we turn to the scoring function, which is arguably at the heart of a TF-Ranking model. The idea is to compute a relevance score for a (set of) query-document pair(s). The TF-Ranking model will use training data to learn this function.
Here we formulate a scoring function using a feed-forward network. The function takes the features of a single example (i.e., a query-document pair) and produces a relevance score.
In [0]:
def make_score_fn():
    """Returns a scoring function to build `EstimatorSpec`."""

    def _score_fn(context_features, group_features, mode, params, config):
        """Defines the network to score a group of documents."""
        with tf.compat.v1.name_scope("input_layer"):
            context_input = [
                tf.compat.v1.layers.flatten(context_features[name])
                for name in sorted(context_feature_columns())
            ]
            group_input = [
                tf.compat.v1.layers.flatten(group_features[name])
                for name in sorted(example_feature_columns())
            ]
            input_layer = tf.concat(context_input + group_input, 1)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        cur_layer = input_layer
        cur_layer = tf.compat.v1.layers.batch_normalization(
            cur_layer,
            training=is_training,
            momentum=0.99)

        for i, layer_width in enumerate(int(d) for d in _HIDDEN_LAYER_DIMS):
            cur_layer = tf.compat.v1.layers.dense(cur_layer, units=layer_width)
            cur_layer = tf.compat.v1.layers.batch_normalization(
                cur_layer,
                training=is_training,
                momentum=0.99)
            cur_layer = tf.nn.relu(cur_layer)
            cur_layer = tf.compat.v1.layers.dropout(
                inputs=cur_layer, rate=_DROPOUT_RATE, training=is_training)
        logits = tf.compat.v1.layers.dense(cur_layer, units=_GROUP_SIZE)
        return logits

    return _score_fn
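With `_GROUP_SIZE = 1`, `_score_fn` sees one query-document pair at a time: it flattens and concatenates the query and document embeddings, passes them through the hidden layers, and emits one logit, and the same network is applied to every document in the list. The plain-Python sketch below mimics that flow with a single hidden layer and hand-picked toy weights, omitting batch normalization and dropout; every name in it is hypothetical.

```python
from typing import List, Sequence

Vector = List[float]
Matrix = Sequence[Sequence[float]]


def dense(x: Sequence[float], weights: Matrix, bias: Sequence[float]) -> Vector:
    """Computes x @ W + b, with W stored as one row per input dimension."""
    return [
        sum(x_i * w_row[j] for x_i, w_row in zip(x, weights)) + bias[j]
        for j in range(len(bias))
    ]


def relu(x: Sequence[float]) -> Vector:
    return [max(0.0, v) for v in x]


def score_list(query: Sequence[float], documents: Sequence[Sequence[float]],
               hidden_w: Matrix, hidden_b: Sequence[float],
               out_w: Matrix, out_b: Sequence[float]) -> Vector:
    """Pointwise scoring (group size 1): each document is scored on its own
    from the concatenation of the query vector and that document's vector."""
    scores = []
    for doc in documents:
        features = list(query) + list(doc)  # analogous to tf.concat(..., 1)
        hidden = relu(dense(features, hidden_w, hidden_b))
        scores.append(dense(hidden, out_w, out_b)[0])  # one logit per document
    return scores


# Toy sizes: 1-d query vector, 1-d document vectors, 2 hidden units.
hidden_w = [[1.0, 0.0], [0.0, 1.0]]
hidden_b = [0.0, 0.0]
out_w = [[1.0], [1.0]]
out_b = [0.0]
print(score_list([1.0], [[2.0], [-3.0], [0.5]], hidden_w, hidden_b, out_w, out_b))
# [3.0, 1.0, 1.5]
```

Since each document is scored on its own here, one answer's score does not depend on the other answers in its list.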
We have provided an implementation of several popular Information Retrieval evaluation metrics in the TF-Ranking library, which are shown here. The user can also define a custom evaluation metric, as shown in the docstring of eval_metric_fns below.
In [0]:
def eval_metric_fns():
    """Returns a dict from name to metric functions.

    This can be customized as follows. Care must be taken when handling padded
    lists.

    def _auc(labels, predictions, features):
        is_label_valid = tf.reshape(tf.greater_equal(labels, 0.), [-1, 1])
        clean_labels = tf.boolean_mask(tf.reshape(labels, [-1, 1]), is_label_valid)
        clean_pred = tf.boolean_mask(tf.reshape(predictions, [-1, 1]), is_label_valid)
        return tf.compat.v1.metrics.auc(clean_labels, tf.sigmoid(clean_pred), ...)
    metric_fns["auc"] = _auc

    Returns:
        A dict mapping from metric name to a metric function with above signature.
    """
    metric_fns = {}
    metric_fns.update({
        "metric/ndcg@%d" % topn: tfr.metrics.make_ranking_metric_fn(
            tfr.metrics.RankingMetricKey.NDCG, topn=topn)
        for topn in [1, 3, 5, 10]
    })
    return metric_fns
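The "care with padded lists" point can be made concrete with a custom metric in plain Python. The sketch below first drops entries whose label is negative (the padding label), then computes a pairwise ordering accuracy, a close relative of AUC: the fraction of differently-labelled answer pairs that the scores put in the right order. The function name is hypothetical and this is not a TF-Ranking metric.

```python
from itertools import combinations
from typing import Sequence


def pairwise_accuracy(labels: Sequence[float], scores: Sequence[float]) -> float:
    """Fraction of differently-labelled pairs whose scores agree with the labels.

    Padded entries (label < 0) are removed first; ties in score count as wrong.
    """
    valid = [(l, s) for l, s in zip(labels, scores) if l >= 0]
    correct = total = 0
    for (l_a, s_a), (l_b, s_b) in combinations(valid, 2):
        if l_a == l_b:
            continue  # equal labels carry no ordering information
        total += 1
        if (l_a - l_b) * (s_a - s_b) > 0:
            correct += 1
    return correct / total if total else 0.0


labels = [5, 3, 1, -1, -1]          # the last two entries are padding
scores = [2.0, 0.5, 1.0, 9.0, 9.0]  # scores of padded slots are arbitrary
print(pairwise_accuracy(labels, scores))  # 2 of the 3 valid pairs are ordered correctly
```

Without the masking step, the high scores on the padded slots would be counted as ordering mistakes.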
We provide several popular ranking loss functions as part of the library, which are shown here. The user can also define a custom loss function, similar to those in tfr.losses.
In [0]:
# Define a loss function. To find a complete list of available
# loss functions or to learn how to add your own custom function
# please refer to the tensorflow_ranking.losses module.
_LOSS = tfr.losses.RankingLossKey.APPROX_NDCG_LOSS
loss_fn = tfr.losses.make_loss_fn(_LOSS)
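The idea behind an approximate NDCG loss is that a document's rank is a step function of the scores and has no useful gradient, so it is replaced by a smooth approximation built from sigmoids of pairwise score differences, and the resulting NDCG is maximized (its negative is minimized). The plain-Python sketch below illustrates this for a single unpadded list; the temperature `alpha` and the function names are our own illustrative choices, and TF-Ranking's `APPROX_NDCG_LOSS` additionally handles batching, padding, and its own parameterization.

```python
import math
from typing import List, Sequence


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def approx_ranks(scores: Sequence[float], alpha: float = 10.0) -> List[float]:
    """Smooth ranks: 1 + sum over j != i of sigmoid(alpha * (s_j - s_i))."""
    return [
        1.0 + sum(_sigmoid(alpha * (s_j - s_i))
                  for j, s_j in enumerate(scores) if j != i)
        for i, s_i in enumerate(scores)
    ]


def approx_ndcg_loss(labels: Sequence[float], scores: Sequence[float],
                     alpha: float = 10.0) -> float:
    """Negative NDCG computed with smooth ranks, so it varies smoothly with scores."""
    ranks = approx_ranks(scores, alpha)
    dcg = sum((2.0 ** l - 1.0) / math.log2(1.0 + r) for l, r in zip(labels, ranks))
    ideal = sum((2.0 ** l - 1.0) / math.log2(1.0 + p)
                for p, l in enumerate(sorted(labels, reverse=True), start=1))
    return -dcg / ideal if ideal > 0 else 0.0


print([round(r, 3) for r in approx_ranks([3.0, 2.0, 1.0])])   # [1.0, 2.0, 3.0]
print(round(approx_ndcg_loss([2, 1, 0], [3.0, 2.0, 1.0]), 3))  # -1.0 (near-ideal order)
```

Larger `alpha` makes the approximate ranks closer to the true ranks but the loss surface sharper.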
In [0]:
optimizer = tf.compat.v1.train.AdagradOptimizer(
    learning_rate=_LEARNING_RATE)


def _train_op_fn(loss):
    """Defines train op used in ranking head."""
    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    minimize_op = optimizer.minimize(
        loss=loss, global_step=tf.compat.v1.train.get_global_step())
    train_op = tf.group([update_ops, minimize_op])
    return train_op


ranking_head = tfr.head.create_ranking_head(
    loss_fn=loss_fn,
    eval_metric_fns=eval_metric_fns(),
    train_op_fn=_train_op_fn)
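`_train_op_fn` groups the ops in `UPDATE_OPS` with the minimize op because the `tf.compat.v1.layers.batch_normalization` calls in the scoring function keep moving averages of the batch statistics, and those averages are updated by separate ops rather than by the optimizer. If the update ops never run, evaluation and prediction would use moving statistics that never moved from their initial values. The plain-Python sketch below (not TensorFlow code; the helper name is hypothetical) shows the exponential-moving-average update with the `momentum=0.99` used above.

```python
from typing import Sequence


def update_moving_mean(moving_mean: float, batch: Sequence[float],
                       momentum: float = 0.99) -> float:
    """One exponential-moving-average step, as performed by a batch-norm update op."""
    batch_mean = sum(batch) / len(batch)
    return momentum * moving_mean + (1.0 - momentum) * batch_mean


moving_mean = 0.0  # batch normalization starts its moving mean at zero
for _ in range(500):
    # Every training batch here has mean 2.0; the moving mean drifts toward it.
    moving_mean = update_moving_mean(moving_mean, [1.0, 2.0, 3.0])
print(round(moving_mean, 2))  # 1.99
```

With momentum 0.99, the moving mean needs several hundred updates to approach the true batch statistics.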
In [0]:
model_fn = tfr.model.make_groupwise_ranking_fn(
group_score_fn=make_score_fn(),
transform_fn=make_transform_fn(),
group_size=_GROUP_SIZE,
ranking_head=ranking_head)
In [0]:
def train_and_eval_fn():
    """Train and eval function used by `tf.estimator.train_and_evaluate`."""
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=1000)
    ranker = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=_MODEL_DIR,
        config=run_config)
    train_input_fn = lambda: input_fn(_TRAIN_DATA_PATH)
    eval_input_fn = lambda: input_fn(_TEST_DATA_PATH, num_epochs=1)
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=_NUM_TRAIN_STEPS)
    eval_spec = tf.estimator.EvalSpec(
        name="eval",
        input_fn=eval_input_fn,
        throttle_secs=15)
    return (ranker, train_spec, eval_spec)
In [0]:
! rm -rf "/tmp/ranking_model_dir" # Clean up the model directory.
ranker, train_spec, eval_spec = train_and_eval_fn()
tf.estimator.train_and_evaluate(ranker, train_spec, eval_spec)
In [0]:
%load_ext tensorboard
%tensorboard --logdir="/tmp/ranking_model_dir" --port 12345
A sample tensorboard output is shown here, with the ranking metrics.
We show how to generate predictions over the features of a dataset. We assume that the label is not present and needs to be inferred using the ranking model.
Similar to the input_fn used for training and evaluation, predict_input_fn reads data stored as TFRecords in ELWC format to generate features. We set the number of epochs to 1, so that the generator stops iterating when it reaches the end of the dataset. The data points are also not shuffled while reading, so that the behavior of the predict() function is deterministic.
In [0]:
def predict_input_fn(path):
    context_feature_spec = tf.feature_column.make_parse_example_spec(
        context_feature_columns().values())
    example_feature_spec = tf.feature_column.make_parse_example_spec(
        list(example_feature_columns().values()))
    dataset = tfr.data.build_ranking_dataset(
        file_pattern=path,
        data_format=tfr.data.ELWC,
        batch_size=_BATCH_SIZE,
        list_size=_LIST_SIZE,
        context_feature_spec=context_feature_spec,
        example_feature_spec=example_feature_spec,
        reader=tf.data.TFRecordDataset,
        shuffle=False,
        num_epochs=1)
    features = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    return features
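The reading behavior described for predict_input_fn (a single pass in a fixed order, consumed in batches, ending when the data runs out) can be illustrated with a plain-Python generator. This is only an analogy for `build_ranking_dataset(..., shuffle=False, num_epochs=1)`: the function name is hypothetical, and the sketch keeps a final partial batch, which may not match how the real pipeline handles the remainder.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batches_one_epoch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yields consecutive batches in input order and stops after one pass.

    No shuffling, so every call produces the same batches; the final batch
    may be smaller than `batch_size`.
    """
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return  # end of the single epoch: the generator is exhausted
        yield batch


# Hypothetical stand-ins for the 200 test queries.
query_ids = list(range(200))
sizes = [len(b) for b in batches_one_epoch(query_ids, batch_size=32)]
print(sizes)  # [32, 32, 32, 32, 32, 32, 8]
```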
We generate predictions on the test dataset, using only the context and example features to predict the labels. predict_input_fn reads the features in batches; batching lets us iterate over datasets too large to load into memory.
In [0]:
predictions = ranker.predict(input_fn=lambda: predict_input_fn("/tmp/test.tfrecords"))
ranker.predict returns a generator, which we can iterate over to obtain predictions until the generator is exhausted.
In [0]:
x = next(predictions)
assert len(x) == _LIST_SIZE  # Note that this includes padding.
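Each `x` holds one score per slot of the list, padded slots included, so turning it into an ordering of a query's actual answers means discarding the padded slots and sorting the rest by score. The sketch below assumes you already know which slots hold real answers (for example, from how many answers the query has); the helper name and that validity mask are hypothetical, not part of TF-Ranking. It only indexes `scores`, so it works on a NumPy array such as `x` as well as on a list.

```python
from typing import List, Sequence


def rank_valid_documents(scores: Sequence[float],
                         is_valid: Sequence[bool]) -> List[int]:
    """Indices of valid (non-padded) documents, highest score first."""
    valid_indices = [i for i, ok in enumerate(is_valid) if ok]
    return sorted(valid_indices, key=lambda i: scores[i], reverse=True)


# Toy list of size 5 whose last two slots are padding.
scores = [0.1, 2.3, -0.4, 5.0, 5.0]
is_valid = [True, True, True, False, False]
print(rank_valid_documents(scores, is_valid))  # [1, 0, 2]
```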