In [0]:
# Copyright 2019 The T5 Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Fine-Tuning the Text-To-Text Transfer Transformer (T5) for Closed-Book Question Answering

Or: What does T5 know?

The following tutorial guides you through the process of fine-tuning a pre-trained T5 model, evaluating its accuracy, and using it for prediction, all on a free Google Cloud TPU.

Background

T5 was introduced in the paper Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. In that paper, we provided a comprehensive picture of how we pre-trained a standard text-to-text Transformer model on a large text corpus, achieving state-of-the-art results on many NLP tasks after fine-tuning.

We pre-trained T5 on a mixture of supervised and unsupervised tasks, with the majority of the data coming from an unlabeled dataset we developed called C4. C4 is based on a massive scrape of the web produced by Common Crawl. Loosely speaking, pre-training on C4 ideally gives T5 an understanding of natural language in addition to general world knowledge.

How can we assess what T5 knows?

As the name implies, T5 is a text-to-text model, which enables us to train it on arbitrary tasks involving a textual input and output. As we showed in our paper, a huge variety of NLP tasks can be cast in this format, including translation, summarization, and even classification and regression tasks.
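
For intuition, here are a few illustrative input/target pairs in the text-to-text format. These are hand-written examples in the spirit of the paper (the last one mirrors this notebook's task); they are not produced by the code below:

# Illustrative (hand-written) examples of tasks cast as text-to-text pairs.
text_to_text_examples = [
    # Translation
    ("translate English to German: That is good.", "Das ist gut."),
    # Classification (CoLA acceptability judgment)
    ("cola sentence: The course is jumping well.", "not acceptable"),
    # Closed-book QA, the task we fine-tune on in this notebook
    ("trivia question: who won the 2017 sports personality of the year?",
     "mo farah"),
]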

One way to use this text-to-text framework is on reading comprehension problems, where the model is fed some context along with a question and is trained to predict the question's answer. For example, we might feed the model the text from the Wikipedia article about Hurricane Connie along with the question "On what date did Hurricane Connie occur?" and train the model to predict the answer "August 3rd, 1955". A related task is open-domain question answering (QA), where the model is not provided with this oracle context. Typically, open-domain QA systems include a mechanism to look up information in an external knowledge source. This setting is similar to an "open-book" exam.

In this notebook, we'll be training T5 on a variant of this task which we call closed-book question answering. In closed-book QA, we feed the model a question without any context or access to external knowledge and train it to predict the answer. Since the model doesn't receive any context, the primary way it can learn to answer these questions is based on the "knowledge" it obtained during pre-training. We don't expect T5 to contain super specific information, so we will be focusing on two question-answering datasets which largely include trivia questions (i.e. facts about well-known subjects). Similar investigations have recently been done to test the knowledge stored by BERT and GPT-2.

T5 was not pre-trained on closed-book QA, so in this notebook we'll first create two new tasks and then use the t5 library to fine-tune, evaluate, and obtain predictions from T5. In the end, T5's performance on closed-book QA can give us a sense of what kind (and how much) information T5 managed to learn during pre-training.

State-of-the-art Results

We published a more in-depth investigation of closed-book QA with T5, where we achieved SOTA on open-domain variants of WebQuestions and TriviaQA, in addition to surprisingly strong results on Natural Questions. The code in this notebook is a simplified version of those experiments but still produces good results.

For code to reproduce our best results, please see the t5_closed_book_qa repo.

Caveats

  • While we provide instructions for running on a Cloud TPU via Colab for free, a Google Cloud Storage (GCS) bucket is required for storing model parameters and data. The GCS free tier provides 5 GB of storage, which should be enough to store checkpoints for the large model and smaller, but not for the 3B or 11B parameter models. You can use part of your initial $300 credit to get more space.
  • The Cloud TPU provided by Colab (a v2-8) does not have enough memory to fine-tune the 11B parameter model. For this model, you will need to fine-tune inside of a GCP instance (see README).

Set Up

  Train on TPU

  1. Create a Cloud Storage bucket for your data and model checkpoints at http://console.cloud.google.com/storage, and fill in the BASE_DIR parameter in the following form. There is a free tier if you do not yet have an account.

  2. On the main menu, click Runtime and select Change runtime type. Set "TPU" as the hardware accelerator.

  3. Run the following cell and follow instructions to:
    • Set up a Colab TPU running environment
    • Verify that you are connected to a TPU device
    • Upload your credentials to TPU to access your GCS bucket

In [0]:
print("Installing dependencies...")
%tensorflow_version 2.x
!pip install -q t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

BASE_DIR = "gs://" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "v2-8"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  yield
  tf.logging.set_verbosity(og_level)

Creating new Tasks and a Mixture

Two core components of the T5 library are Task and Mixture objects.

A Task is a dataset along with preprocessing functions and evaluation metrics. A Mixture is a collection of Task objects along with a mixing rate or a function defining how to compute a mixing rate based on the properties of the constituent Tasks.

For this example, we will fine-tune the model to do closed-book question answering.

Natural Questions

Natural Questions (NQ) is a challenging corpus for open-domain QA. Each example includes a question along with an entire Wikipedia article that may or may not contain its answer. The goal is to produce the correct answer given this context. In our case, we will be ignoring the provided context in hopes that the model will learn to find the answers from the world knowledge it has acquired during pre-training.

Since the raw data splits are stored as JSONL files, we will first need to convert them to TSV format to make them parseable in TensorFlow. We will also take the opportunity to drop information we will not be using, remove questions with multiple answers, and do a bit of cleaning of the text.


In [0]:
import gzip
import json

# Public directory of Natural Questions data on GCS.
NQ_JSONL_DIR = "gs://natural_questions/v1.0-simplified/"
NQ_SPLIT_FNAMES = {
    "train": "simplified-nq-train.jsonl.gz",
    "validation": "nq-dev-all.jsonl.gz"
}
nq_counts_path = os.path.join(DATA_DIR, "nq-counts.json")
nq_tsv_path = {
    "train": os.path.join(DATA_DIR, "nq-train.tsv"),
    "validation": os.path.join(DATA_DIR, "nq-validation.tsv")
}

def nq_jsonl_to_tsv(in_fname, out_fname):

  def extract_answer(tokens, span):
    """Reconstruct answer from token span and remove extra spaces."""
    start, end = span["start_token"], span["end_token"]  
    ans = " ".join(tokens[start:end])
    # Remove incorrect spacing around punctuation.
    ans = ans.replace(" ,", ",").replace(" .", ".").replace(" %", "%")
    ans = ans.replace(" - ", "-").replace(" : ", ":").replace(" / ", "/")
    ans = ans.replace("( ", "(").replace(" )", ")")
    ans = ans.replace("`` ", "\"").replace(" ''", "\"")
    ans = ans.replace(" 's", "'s").replace("s ' ", "s' ")
    return ans

  count = 0
  with tf.io.gfile.GFile(in_fname, "rb") as infile,\
       tf.io.gfile.GFile(out_fname, "w") as outfile:
    for line in gzip.open(infile):
      ex = json.loads(line)
      # Remove any examples with more than one answer.
      if len(ex['annotations'][0]['short_answers']) != 1:
        continue
      # Questions in NQ do not include a question mark.
      question = ex["question_text"] + "?"
      answer_span = ex['annotations'][0]['short_answers'][0]
      # Handle the two document formats in NQ (tokens or text).
      if "document_tokens" in ex:
        tokens = [t["token"] for t in ex["document_tokens"]]
      elif "document_text" in ex:
        tokens = ex["document_text"].split(" ")
      answer = extract_answer(tokens, answer_span)
      # Write this line as <question>\t<answer>
      outfile.write("%s\t%s\n" % (question, answer))
      count += 1
      tf.logging.log_every_n(
          tf.logging.INFO,
          "Wrote %d examples to %s." % (count, out_fname),
          1000)
    return count

if tf.io.gfile.exists(nq_counts_path):
  # Use cached data and counts.
  tf.logging.info("Loading NQ from cache.")
  num_nq_examples = json.load(tf.io.gfile.GFile(nq_counts_path))
else:
  # Create TSVs and get counts.
  tf.logging.info("Generating NQ TSVs.")
  num_nq_examples = {}
  for split, fname in NQ_SPLIT_FNAMES.items():
    num_nq_examples[split] = nq_jsonl_to_tsv(
        os.path.join(NQ_JSONL_DIR, fname), nq_tsv_path[split])
  json.dump(num_nq_examples, tf.io.gfile.GFile(nq_counts_path, "w"))


I1206 00:11:00.169766 248738 <ipython-input-3-45e03d923fbf>:51] Loading NQ from cache.

Next, we define a function to load the TSV data as a tf.data.Dataset in TensorFlow.


In [0]:
def nq_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path[split])
  # Split each "<question>\t<answer>" example into (question, answer) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"question": ... "answer": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["question", "answer"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(nq_dataset_fn("validation").take(5)):
  print(ex)


A few raw validation examples...
{'question': b'what do the 3 dots mean in math?', 'answer': b'the therefore sign (\xe2\x88\xb4) is generally used before a logical consequence, such as the conclusion of a syllogism'}
{'question': b'who is playing the halftime show at super bowl 2016?', 'answer': b'Coldplay with special guest performers Beyonc\xc3\xa9 and Bruno Mars'}
{'question': b'who won the 2017 sports personality of the year?', 'answer': b'Mo Farah'}
{'question': b'where was the world economic forum held this year?', 'answer': b'Davos, a mountain resort in Graub\xc3\xbcnden, in the eastern Alps region of Switzerland'}
{'question': b'who has made the most premier league appearances?', 'answer': b'Gareth Barry'}

Now, we write a preprocess function to convert the examples in the tf.data.Dataset into a text-to-text format, with both inputs and targets fields. The preprocessor also normalizes the text by lowercasing it and removing quotes since the answers are sometimes formatted in odd ways. Finally, we prepend 'trivia question:' to the inputs so that the model knows what task it's trying to solve.


In [0]:
def trivia_preprocessor(ds):
  def normalize_text(text):
    """Lowercase and remove quotes from a TensorFlow string."""
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"question": ..., "answer": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["trivia question: ", normalize_text(ex["question"])]),
        "targets": normalize_text(ex["answer"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

Finally, we put everything together to create a Task.


In [0]:
t5.data.TaskRegistry.add(
    "nq_context_free",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=nq_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[trivia_preprocessor],
    # Lowercase targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=num_nq_examples
)

Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [0]:
nq_task = t5.data.TaskRegistry.get("nq_context_free")
ds = nq_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed validation examples...
{'inputs_plaintext': b'trivia question: what is the average height of a chinese man?', 'inputs': array([22377,   822,    10,   125,    19,     8,  1348,  3902,    13,
           3,     9,     3,  1436,  1496,    15,   388,    58,     1]), 'targets_plaintext': b'167.1 cm (5 ft 6 in)', 'targets': array([  898, 25059,  2446,  9209,     3,    89,    17,   431,    16,
          61,     1])}
{'inputs_plaintext': b'trivia question: what is the population of fayetteville north carolina?', 'inputs': array([22377,   822,    10,   125,    19,     8,  2074,    13,     3,
          89,     9,    63,  1954,  1420,  3457,   443, 12057,     9,
          58,     1]), 'targets_plaintext': b'204,408 in 2013', 'targets': array([    3, 26363,     6,  2445,   927,    16,  2038,     1])}
{'inputs_plaintext': b'trivia question: capital of georgia the former soviet republic 7 letters?', 'inputs': array([22377,   822,    10,  1784,    13,   873,  1677,    23,     9,
           8,  1798,    78,  5914,    17, 20237,   489,  5487,    58,
           1]), 'targets_plaintext': b'tbilisi', 'targets': array([   3,   17, 3727,  159,   23,    1])}
{'inputs_plaintext': b'trivia question: who plays jill bigelow in line of duty?', 'inputs': array([22377,   822,    10,   113,  4805,     3,   354,  1092,   600,
          15,  3216,    16,   689,    13,  5461,    58,     1]), 'targets_plaintext': b'polly walker', 'targets': array([ 5492,    63,     3, 24063,     1])}
{'inputs_plaintext': b'trivia question: when did we first put a rover on mars?', 'inputs': array([22377,   822,    10,   116,   410,    62,   166,   474,     3,
           9,     3,    52,  1890,    30,  8113,    58,     1]), 'targets_plaintext': b'january 2004', 'targets': array([   3, 7066,   76, 1208, 4406,    1])}

Note: Instead of defining nq_dataset_fn above, we could have used the TextLineTask class with the parse_tsv preprocessor for equivalent results, as follows:

t5.data.TaskRegistry.add(
    "nq_context_free",
    t5.data.TextLineTask,
    split_to_filepattern=nq_tsv_path,
    text_preprocessor=[
      functools.partial(
          t5.data.preprocessors.parse_tsv,
          field_names=["question", "answer"]),
      trivia_preprocessor
    ],
    postprocess_fn=t5.data.postprocessors.lower_text, 
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples
)

TriviaQA

The second dataset we will use is TriviaQA. It is also intended for reading comprehension, but, once again, we will modify the task by ignoring the provided context.

Since the dataset has been imported into TensorFlow Datasets (TFDS), we can let it handle the data parsing for us. It will take a few minutes to download and preprocess the first time, but we'll be able to access it instantly from our data directory afterward.


In [0]:
ds = tfds.load(
    "trivia_qa/unfiltered.nocontext",
    data_dir=DATA_DIR,
    # Download data locally for preprocessing to avoid using GCS space.
    download_and_prepare_kwargs={"download_dir": "./downloads"})
print("A few raw validation examples...")
for ex in tfds.as_numpy(ds["validation"].take(2)):
  print(ex)


A few raw validation examples...
{'answer': {'aliases': array([b'Torquemada (disambiguation)', b'Torquemada'], dtype=object), 'matched_wiki_entity_name': b'', 'normalized_aliases': array([b'torquemada', b'torquemada disambiguation'], dtype=object), 'normalized_matched_wiki_entity_name': b'', 'normalized_value': b'torquemada', 'type': b'WikipediaEntity', 'value': b'Torquemada'}, 'entity_pages': {'doc_source': array([], dtype=object), 'filename': array([], dtype=object), 'title': array([], dtype=object), 'wiki_context': array([], dtype=object)}, 'question': b'In 1483, who was appointed the first grand inquisitor of the Spanish Inquisition?', 'question_id': b'qw_16011', 'question_source': b'http://www.quizwise.com/', 'search_results': {'description': array([], dtype=object), 'filename': array([], dtype=object), 'rank': array([], dtype=int32), 'search_context': array([], dtype=object), 'title': array([], dtype=object), 'url': array([], dtype=object)}}
{'answer': {'aliases': array([b'Austerlitz (disambiguation)', b'Austerlitz', b'AUSTERLITZ'],
      dtype=object), 'matched_wiki_entity_name': b'', 'normalized_aliases': array([b'austerlitz', b'austerlitz disambiguation'], dtype=object), 'normalized_matched_wiki_entity_name': b'', 'normalized_value': b'austerlitz', 'type': b'WikipediaEntity', 'value': b'AUSTERLITZ'}, 'entity_pages': {'doc_source': array([], dtype=object), 'filename': array([], dtype=object), 'title': array([], dtype=object), 'wiki_context': array([], dtype=object)}, 'question': b'Which celebrated battle was fought near Brno on 2nd December 1805?', 'question_id': b'dpql_4053', 'question_source': b'https://derbyshirepubquizleague.wordpress.com/', 'search_results': {'description': array([], dtype=object), 'filename': array([], dtype=object), 'rank': array([], dtype=int32), 'search_context': array([], dtype=object), 'title': array([], dtype=object), 'url': array([], dtype=object)}}

As with Natural Questions, we need to preprocess the raw examples into inputs and targets features. We can reuse the trivia_preprocessor above, but first we need to convert the TriviaQA examples into the correct format, ignoring the fields we don't need for our task.

We'll then define our Task and print out a few preprocessed examples from the validation set.

Note that we do not need to specify the splits or number of examples since that information is provided by TFDS.


In [0]:
def triviaqa_extract_qa(ds):
  def extract_qa(ex):
    return {
        "question": ex["question"],
        "answer": ex["answer"]["value"]
    }
  return ds.map(extract_qa, num_parallel_calls=tf.data.experimental.AUTOTUNE)

t5.data.TaskRegistry.add(
    "triviaqa_context_free",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    t5.data.TfdsTask,
    tfds_name="trivia_qa/unfiltered.nocontext:1.1.0",
    tfds_data_dir=DATA_DIR,
    text_preprocessor=[triviaqa_extract_qa, trivia_preprocessor],
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[t5.evaluation.metrics.accuracy]
)

# Load and print a few examples.
triviaqa_task = t5.data.TaskRegistry.get("triviaqa_context_free")
ds = triviaqa_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(3)):
  print(ex)


A few preprocessed validation examples...
{'inputs_plaintext': b'trivia question: what does a farrier do?', 'inputs': array([22377,   822,    10,   125,   405,     3,     9,   623,  6711,
         103,    58,     1]), 'targets_plaintext': b'he shoes horses', 'targets': array([    3,    88,  4439, 10235,     1])}
{'inputs_plaintext': b'trivia question: what is the name of the wooden panelled lining applied to a room', 'inputs': array([22377,   822,    10,   125,    19,     8,   564,    13,     8,
        5726,  2952,  1361,     3,  9424,  2930,    12,     3,     9,
         562,     1]), 'targets_plaintext': b'wainscotting', 'targets': array([    3,   210, 13676, 10405,    53,     1])}
{'inputs_plaintext': b'trivia question: how did gus grissom, ed white and roger b. chaffee die in 1967?', 'inputs': array([22377,   822,    10,   149,   410,     3,  1744,     7, 19116,
       10348,     6,     3,    15,    26,   872,    11,     3,  3822,
          49,     3,   115,     5,     3,  3441,  7398,    15,    67,
          16, 18148,    58,     1]), 'targets_plaintext': b'burned to death', 'targets': array([16644,    12,  1687,     1])}

Dataset Mixture

We now create a Mixture from the above Tasks, which we will fine-tune on.

There are different ways to automatically set the rate (for example, based on the number of examples using rate_num_examples), but we will just hardcode an equal mixture for simplicity.


In [0]:
t5.data.MixtureRegistry.remove("trivia_all")
t5.data.MixtureRegistry.add(
    "trivia_all",
    ["nq_context_free", "triviaqa_context_free"],
     default_rate=1.0
)
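
If you instead wanted the rate to scale with the number of examples in each Task, a sketch like the following should work (assuming t5.data.rate_num_examples is available in your version of the library):

t5.data.MixtureRegistry.add(
    "trivia_all_proportional",
    ["nq_context_free", "triviaqa_context_free"],
    # Mix in proportion to each Task's number of examples.
    default_rate=t5.data.rate_num_examples
)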

Transferring to new Tasks

We are now ready to fine-tune one of the pre-trained T5 models on our new mixture of closed-book QA tasks.

First, we'll instantiate a Model object using the model size of your choice. Note that larger models are slower to train and use but will likely achieve higher accuracy. You may also be able to increase accuracy by training longer with more FINETUNE_STEPS below.

Caveats

  • Due to its memory requirements, you will not be able to train the 11B parameter model on the TPU provided by Colab. Instead, you will need to fine-tune inside of a GCP instance (see README).
  • Due to the checkpoint size, you will not be able to use the 5GB GCS free tier for the 3B parameter model. You will need at least 25GB of space, which you can purchase with your initial $300 credit on GCP.
  • While the large model can achieve decent results, we recommend fine-tuning at least the 3B parameter model.

Define Model


In [0]:
MODEL_SIZE = "3B" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warn(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# The models from our paper are based on the Mesh Tensorflow Transformer.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

Before we continue, let's load a TensorBoard visualizer so that we can monitor our progress. The page should automatically update as fine-tuning and evaluation proceed.


In [0]:
if ON_CLOUD:
  %reload_ext tensorboard
  import tensorboard as tb
  tb.notebook.start("--logdir " + MODELS_DIR)

Fine-tune

We are now ready to fine-tune our model. This will take a while (~2 hours with default settings), so please be patient! The larger the model and the more FINETUNE_STEPS you use, the longer it will take.

Don't worry, you can always come back later and increase the number of steps, and it will automatically pick up where you left off.


In [0]:
FINETUNE_STEPS = 25000 #@param {type: "integer"}

model.finetune(
    mixture_or_task_name="trivia_all",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

Expected Results [SPOILER ALERT]

Below are the expected accuracies on the Natural Questions (NQ) and TriviaQA validation sets for various model sizes. The full 11B model produces the exact text of the answer 34.5% and 25.1% of the time on TriviaQA and NQ, respectively. The 3B parameter model, which is the largest that can be trained with a free Cloud TPU in Colab, achieves 29.7% and 23.7%, respectively.

In reality, the model performs better than these numbers suggest, since requiring an exact match is too strict a metric, as you'll see in the examples below. This also helps explain why the model appears to perform better on TriviaQA than NQ, as the latter tends to include more long-form answers extracted from the context.

Please see our paper on closed-book QA, where we achieved even better results.

Evaluate

We now evaluate on the validation sets of the tasks in our mixture. Accuracy results will be logged and added to the TensorBoard above.


In [0]:
# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="trivia_all",
    checkpoint_steps="all"
)

Let's look at a few random predictions from the validation sets. Note that we measure accuracy based on an exact match of the predicted answer and the ground-truth answer. As a result, some of the answers are semantically correct but are counted wrong by the exact match score.
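
For instance, a prediction that is semantically right but worded differently still scores zero under exact match. A quick hand-written illustration (not part of the evaluation code), using a pair seen in the predictions below:

# Exact match counts a prediction as correct only if it is identical to the target.
examples = [("standup comedy", "comedy"), ("12th", "12th")]
accuracy = sum(target == pred for target, pred in examples) / len(examples)
print(accuracy)  # 0.5, even though "comedy" is arguably a correct answer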


In [0]:
import random

def print_random_predictions(task_name, n=10):
  """Print n predictions from the validation split of a task."""
  # Grab the dataset for this task.
  ds = t5.data.TaskRegistry.get(task_name).get_dataset(
      split="validation",
      sequence_length={"inputs": 128, "targets": 32},
      shuffle=False)

  def _prediction_file_to_ckpt(path):
    """Extract the global step from a prediction filename."""
    return int(path.split("_")[-2])

  # Grab the paths of all logged predictions.
  prediction_files = tf.io.gfile.glob(
      os.path.join(
          MODEL_DIR,
          "validation_eval/%s_*_predictions" % task_name))
  # Get most recent prediction file by sorting by their step.
  latest_prediction_file = sorted(
      prediction_files, key=_prediction_file_to_ckpt)[-1]

  # Collect (inputs, targets, prediction) from the dataset and predictions file
  results = []
  with tf.io.gfile.GFile(latest_prediction_file) as preds:
    for ex, pred in zip(tfds.as_numpy(ds), preds):
      results.append((tf.compat.as_text(ex["inputs_plaintext"]),
                      tf.compat.as_text(ex["targets_plaintext"]),
                      pred.strip()))

  print("<== Random predictions for %s using checkpoint %s ==>\n" %
        (task_name, 
         _prediction_file_to_ckpt(latest_prediction_file)))

  for inp, tgt, pred in random.choices(results, k=n):
    print("Input:", inp)
    print("Target:", tgt)
    print("Prediction:", pred)
    print("Counted as Correct?", tgt == pred)
    print()

print_random_predictions("triviaqa_context_free")
print_random_predictions("nq_context_free")


<== Random predictions for triviaqa_context_free using checkpoint 1100000 ==>

Input: trivia question: jackpot counter, ghost drop and drop zone are all terms used in which uk television game show?
Target: tipping point
Prediction: countdown
Counted as Correct? False

Input: trivia question: cursed to sail around the cape of good hope, which ghost ship is the theme of an 1841 opera by richard wagner?
Target: the flying dutchman
Prediction: baron von munchhausen
Counted as Correct? False

Input: trivia question: at what fret are found the same notes as the open strings, but an octave higher, on a standard guitar?
Target: 12th
Prediction: 12th
Counted as Correct? True

Input: trivia question: how many legs does a ladybird have?
Target: six
Prediction: six
Counted as Correct? True

Input: trivia question: in which city’s harbour was the ship queen elizabeth ravaged by fire in 1972?
Target: hong kong
Prediction: hong kong
Counted as Correct? True

Input: trivia question: what are the three largest islands in the world beginning with the letter n
Target: new guinea, north island
Prediction: new zealand; namibia and nova scotia
Counted as Correct? False

Input: trivia question: lenny bruce was in what field of entertainment in the 1960s?
Target: standup comedy
Prediction: comedy
Counted as Correct? False

Input: trivia question: in which sea are the cayman islands?
Target: caribbean
Prediction: caribbean
Counted as Correct? True

Input: trivia question: what is an astronomical event that occurs when one celestial object moves into the shadow of another?
Target: eclipse
Prediction: lunar eclipse
Counted as Correct? False

Input: trivia question: which tv cartoon series was about a meek janitor who led a double life as an unfortunate super-detective?
Target: hong kong fuey
Prediction: scooby-doo
Counted as Correct? False

<== Random predictions for nq_context_free using checkpoint 1100000 ==>

Input: trivia question: who is known as the super fast boy in the series the icredible?
Target: dashiell robert parr/dash
Prediction: dash
Counted as Correct? False

Input: trivia question: who played santa in the santa clause movies?
Target: tim allen
Prediction: tim allen
Counted as Correct? True

Input: trivia question: who has sold more albums kelly or carrie?
Target: carrie underwood
Prediction: carrie underwood
Counted as Correct? True

Input: trivia question: when did sweet caroline start at red sox games?
Target: at least 1997
Prediction: 2004
Counted as Correct? False

Input: trivia question: who plays mr wilson in dennis the menace?
Target: joseph sherrard kearns
Prediction: joseph sherrard kearns
Counted as Correct? True

Input: trivia question: who had a baby at 100 in the bible?
Target: abraham
Prediction: sarah
Counted as Correct? False

Input: trivia question: who is doing 2018 super bowl half time show?
Target: justin timberlake
Prediction: justin timberlake
Counted as Correct? True

Input: trivia question: what is the official slogan for the 2018 winter olympics?
Target: passion. connected.
Prediction: every step counts
Counted as Correct? False

Input: trivia question: ray charles hit the road jack album name?
Target: ray charles greatest hits
Prediction: the road jack album
Counted as Correct? False

Input: trivia question: who sang the theme song to step by step?
Target: jesse frederick james conaway
Prediction: frederick and teresa james
Counted as Correct? False

Predict

Now that we have fine-tuned the model, we can feed T5 arbitrary questions and have it predict the answers!

There is a significant amount of overhead in initializing the model, so this may take a few minutes to run each time, even though the prediction itself is quite fast.

To avoid this overhead, you might consider exporting a SavedModel and running it on Cloud ML Engine.


In [0]:
question_1 = "Where is the Google headquarters located?" #@param {type:"string"}
question_2 = "What is the most populous country in the world?" #@param {type:"string"}
question_3 = "Who are the 4 members of The Beatles?" #@param {type:"string"}
question_4 = "How many teeth do humans have?" #@param {type:"string"}

questions = [question_1, question_2, question_3, question_4]

now = time.time()
# Write out the supplied questions to text files.
predict_inputs_path = os.path.join(MODEL_DIR, "predict_inputs_%d.txt" % now)
predict_outputs_path = os.path.join(MODEL_DIR, "predict_outputs_%d.txt" % now)
# Manually apply preprocessing by prepending "trivia question:".
with tf.io.gfile.GFile(predict_inputs_path, "w") as f:
  for q in questions:
    f.write("trivia question: %s\n" % q.lower())

# Ignore any logging so that we only see the model's answers to the questions.
with tf_verbosity_level('ERROR'):
  model.batch_size = 8  # Min size for small model on v2-8 with parallelism 1.
  model.predict(
      input_file=predict_inputs_path,
      output_file=predict_outputs_path,
      # Select the most probable output token at each step.
      temperature=0,
  )

# The output filename will have the checkpoint appended so we glob to get 
# the latest.
prediction_files = sorted(tf.io.gfile.glob(predict_outputs_path + "*"))
print("\nPredictions using checkpoint %s:\n" % prediction_files[-1].split("-")[-1])
with tf.io.gfile.GFile(prediction_files[-1]) as f:
  for q, a in zip(questions, f):
    if q:
      print("Q: " + q)
      print("A: " + a)
      print()


Predictions using checkpoint 1100000:

Q: Where is the Google headquarters located?
A: mountain view, california


Q: What is the most populous country in the world?
A: china


Q: Who are the 4 members of The Beatles?
A: john lennon, paul mccartney, george harrison and ringo starr


Q: How many teeth do humans have?
A: 30


Export Model for Serving

As mentioned in the previous section, exporting a SavedModel can be useful for improving performance during inference or allowing your model to be deployed on a variety of platforms (e.g., TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub).

Note: we currently only support exporting a SavedModel that runs on both CPU and GPU, not TPU.

Export SavedModel

We first export the SavedModel. We set a batch size of 1 for simplicity, but it may be more efficient to use a larger batch size if you want to handle multiple requests per call.

For the 3B and 11B models, the export will take approximately 30-45 minutes.


In [0]:
export_dir = os.path.join(MODEL_DIR, "export")

model.batch_size = 1 # make one prediction per call
saved_model_path = model.export(
    export_dir,
    checkpoint_step=-1,  # use most recent
    beam_size=1,  # no beam search
    temperature=1.0,  # sample according to predicted distribution
)
print("Model saved to:", saved_model_path)

Load SavedModel

One way to test our model is to load it either in eager mode or a TF 1.x session so that we can repeatedly predict from the model without the overhead of loading the graph and weights each time.

We pay the overhead once here, but it shouldn't take more than a few minutes.

Optional: Switch to GPU Runtime

Changing the runtime type to GPU in the Runtime menu above before loading the SavedModel will speed up inference by using the GPU instead of CPU.


In [0]:
#@title Optional: Run this cell to re-initialize if you switched to GPU runtime.
%tensorflow_version 2.x
!pip install tensorflow-text
from google.colab import auth
auth.authenticate_user()

In [0]:
import tensorflow as tf
import tensorflow_text  # Required to run exported model.

def load_predict_fn(model_path):
  if tf.executing_eagerly():
    print("Loading SavedModel in eager mode.")
    imported = tf.saved_model.load(model_path, ["serve"])
    return lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()
  else:
    print("Loading SavedModel in tf 1.x graph mode.")
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.Session()
    meta_graph_def = tf.compat.v1.saved_model.load(sess, ["serve"], model_path)
    signature_def = meta_graph_def.signature_def["serving_default"]
    return lambda x: sess.run(
        fetches=signature_def.outputs["outputs"].name, 
        feed_dict={signature_def.inputs["input"].name: x}
    )

predict_fn = load_predict_fn(saved_model_path)

Predict

We can now call the loaded prediction function with different inputs each time and get results relatively quickly.


In [0]:
def answer(question):
  return predict_fn([question])[0].decode('utf-8')

for question in ["trivia question: where is the google headquarters?",
                 "trivia question: what is the most populous country in the world?",
                 "trivia question: who are the 4 members of the beatles?",
                 "trivia question: how many teeth do humans have?"]:
    print(answer(question))

Deploy SavedModel

You can now deploy your SavedModel for serving (e.g., with TensorFlow Serving).
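
For example, once a TensorFlow Serving instance is hosting the exported model, you could query it over its REST API. This is only a sketch; the host, port, and model name ("t5_trivia") are assumptions that depend on how you start the server:

import requests

# Query a (hypothetical) TensorFlow Serving instance hosting the exported
# SavedModel under the model name "t5_trivia" on the default REST port.
response = requests.post(
    "http://localhost:8501/v1/models/t5_trivia:predict",
    json={"instances": ["trivia question: where is the google headquarters?"]},
)
print(response.json())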