This codelab will guide you through the implementation of a sequence-to-sequence model using Lingvo.
Sequence-to-sequence models map input sequences of arbitrary length to output sequences of arbitrary length. Example uses of sequence-to-sequence models include machine translation, which maps a sequence of words from one language into a sequence of words in another language with the same meaning; speech recognition, which maps a sequence of acoustic features into a sequence of words; and text summarization, which maps a sequence of words into a shorter sequence which conveys the same meaning.
In this codelab, you will create a model which restores punctuation and capitalization to text which has been lowercased and stripped of punctuation. For example, given the following text:
she asked do you know the way to san jose
The model will output the following properly-punctuated-and-capitalized text:
She asked, "Do you know the way to San Jose"?
We will train an RNMT+ model based on "The Best of Both Worlds: Combining Recent Advances in Neural Machine Translation" (Chen et al., 2018).
The main goal of this codelab is to teach you how to define and train sequence-to-sequence models in Lingvo. We do not aim to teach either Python or TensorFlow, and no sophisticated Python or TensorFlow programming will be required. However, the following will be helpful in understanding this codelab:
This codelab will teach you the following:
This codelab does not:
To start, we need to connect this Colab notebook to a local runtime with Lingvo installed. Run the following commands in a terminal on your machine to install Lingvo and start a local notebook kernel:
mkdir -p /tmp/lingvo_codelab && cd /tmp/lingvo_codelab
pip3 install lingvo
python3 -m lingvo.ipython_kernel
Finally, in the top right-hand corner of this Colab notebook, open the dropdown beside "CONNECT", select "Connect to local runtime...", enter http://localhost:8888, and press CONNECT.
You should now see the word "CONNECTED" and be able to execute the following cell.
In [0]:
import lingvo
In order to train a sequence-to-sequence model, we need a set of pairs of source and target sequences. For this codelab, our source sequences will be text which has been lowercased and had its punctuation removed, and the target sequences will be the original sentences, with their original casing and punctuation.
Since neural networks require numeric inputs, we will also need a tokenization scheme mapping the sequence of characters to a sequence of numbers. In this codelab, we will use a pre-trained word-piece model.
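To make the source/target pairing concrete, here is a small plain-Python sketch (this is not Lingvo code, and the make_pair helper is invented purely for illustration) of how a training pair is derived from an original sentence. The input generator we define below performs the same normalization inside the TensorFlow graph and then converts both strings into word-piece ids.
In [0]:
import string

def make_pair(sentence):
  # The original sentence, with its casing and punctuation, is the target.
  target = sentence
  # The source is the lowercased sentence with punctuation stripped and
  # repeated whitespace collapsed.
  source = sentence.lower().translate(str.maketrans('', '', string.punctuation))
  source = ' '.join(source.split())
  return source, target

print(make_pair('She asked, "Do you know the way to San Jose"?'))
# ('she asked do you know the way to san jose',
#  'She asked, "Do you know the way to San Jose"?')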
We will use the Brown Corpus as the source of our training data. Run the following cell to download and preprocess the dataset. The script generates train.txt and test.txt, containing the training and test data in an 80:20 split, with one sentence per line.
In [0]:
!python3 -m lingvo.tasks.punctuator.tools.download_brown_corpus --outdir=/tmp/punctuator_data
!curl -O https://raw.githubusercontent.com/tensorflow/lingvo/master/lingvo/tasks/punctuator/params/brown_corpus_wpm.16000.vocab
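If you would like to sanity-check the downloaded data, a quick cell like the following prints a few training sentences and the first few entries of the word-piece vocabulary (the paths match the commands above):
In [0]:
# A few sentences from the generated training split.
with open('/tmp/punctuator_data/train.txt') as f:
  for _ in range(3):
    print(f.readline().rstrip())

# The first few entries of the word-piece vocabulary.
with open('brown_corpus_wpm.16000.vocab') as f:
  for _ in range(5):
    print(f.readline().rstrip())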
In order to train a model, we need an input generator that provides structured mini-batches of source-target pairs. The input generator handles all the processing necessary to generate numeric data that can be fed to the model. This includes:
- reading and parsing the input files;
- tokenizing the text into numeric ids;
- bucketing examples by sequence length, padding them, and assembling batches.
Fortunately, the majority of this is handled in the background by Lingvo. We only need to specify how the data should be processed.
Input generators are subclasses of BaseInputGenerator found in lingvo/core/base_input_generator.py and have the following structure:
- A Params classmethod that returns a default Params object for configuring the input generator. Experiment configurations inside Lingvo are controlled using these Params objects.
- An _InputBatch method that returns a NestedMap containing the input batch. NestedMap is an arbitrarily nested map structure used throughout Lingvo.
- A _PreprocessInputBatch method that preprocesses the batch returned by _InputBatch.
Here is an example of the input generator for the Punctuator task, found in lingvo/tasks/punctuator/input_generator.py.
Run the cell below to write the file to disk.
In [0]:
%%writefile input_generator.py
import string
import lingvo.compat as tf
from lingvo.core import base_input_generator
from lingvo.core import base_layer
from lingvo.core import generic_input
from lingvo.core import py_utils
from lingvo.core import tokenizers
class PunctuatorInput(base_input_generator.BaseSequenceInputGenerator):
  """Reads text line by line and processes each line for the punctuator task."""

  @classmethod
  def Params(cls):
    """Default params for PunctuatorInput."""
    p = super(PunctuatorInput, cls).Params()
    p.tokenizer = tokenizers.WpmTokenizer.Params()
    return p

  def _ProcessLine(self, line):
    """A single-text-line processor.

    Gets a string tensor representing a line of text that has been read from
    the input file and converts it into source and target token id sequences.

    We use the original text as the target, and the lowercased,
    punctuation-removed text as the source.

    Args:
      line: a 1D string tensor.

    Returns:
      A list of tensors, in the order expected by __init__.
    """
    # Tokenize the input into integer ids.
    # tgt_ids has the start-of-sentence token prepended, and tgt_labels has the
    # end-of-sentence token appended.
    tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
        tf.convert_to_tensor([line]))

    def Normalize(line):
      # Lowercase and remove punctuation.
      line = line.lower().translate(None, string.punctuation.encode('utf-8'))
      # Convert multiple consecutive spaces to a single one.
      line = b' '.join(line.split())
      return line

    normalized_line = tf.py_func(Normalize, [line], tf.string, stateful=False)
    _, src_labels, src_paddings = self.StringsToIds(
        tf.convert_to_tensor([normalized_line]), is_source=True)
    # The model expects the source without a start-of-sentence token.
    src_ids = src_labels

    # Compute the length for bucketing.
    bucket_key = tf.cast(
        tf.round(
            tf.maximum(
                tf.reduce_sum(1.0 - src_paddings),
                tf.reduce_sum(1.0 - tgt_paddings))), tf.int32)
    tgt_weights = 1.0 - tgt_paddings

    # Return tensors in an order consistent with __init__.
    out_tensors = [
        src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights
    ]
    return [tf.squeeze(t, axis=0) for t in out_tensors], bucket_key

  def _DataSourceFromFilePattern(self, file_pattern):
    """Creates the input processing op.

    Args:
      file_pattern: The file pattern to use as input.

    Returns:
      An op that, when executed, calls `_ProcessLine` on a line read from
      `file_pattern`.
    """
    return generic_input.GenericInput(
        file_pattern=file_pattern,
        processor=self._ProcessLine,
        # Pad dimension 0 to the same length.
        dynamic_padding_dimensions=[0] * 6,
        # The constant values to use for padding each of the outputs.
        dynamic_padding_constants=[0, 1, 0, 1, 0, 0],
        **self.CommonInputOpArgs())

  def __init__(self, params):
    super(PunctuatorInput, self).__init__(params)

    # Build the input processing graph.
    (self._src_ids, self._src_paddings, self._tgt_ids, self._tgt_paddings,
     self._tgt_labels,
     self._tgt_weights), self._bucket_keys = self._BuildDataSource()
    self._sample_ids = tf.range(0, self.InfeedBatchSize(), 1)

  def InfeedBatchSize(self):
    return tf.shape(self._src_ids)[0]

  def _InputBatch(self):
    """Returns a single batch as a `.NestedMap` to be passed to the model."""
    ret = py_utils.NestedMap()
    ret.bucket_keys = self._bucket_keys

    ret.src = py_utils.NestedMap()
    ret.src.ids = tf.cast(self._src_ids, dtype=tf.int32)
    ret.src.paddings = self._src_paddings

    ret.tgt = py_utils.NestedMap()
    ret.tgt.ids = self._tgt_ids
    ret.tgt.labels = tf.cast(self._tgt_labels, dtype=tf.int32)
    ret.tgt.weights = self._tgt_weights
    ret.tgt.paddings = self._tgt_paddings

    return ret
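The _InputBatch method above packs everything the model needs into a NestedMap. If you have not used NestedMap before, the small cell below (purely illustrative) shows the attribute-style access used throughout Lingvo:
In [0]:
from lingvo.core import py_utils

batch = py_utils.NestedMap()
batch.src = py_utils.NestedMap(ids=[1, 2, 3], paddings=[0.0, 0.0, 0.0])
batch.tgt = py_utils.NestedMap(ids=[4, 5], paddings=[0.0, 0.0])
# Entries can be read either as attributes or as dict keys.
print(batch.src.ids)
print(batch['tgt'].paddings)
# Flatten returns all leaf values as a flat list.
print(batch.Flatten())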
Next, we need to define the network structure for the task. The network is a nested structure of layers. Most classes in Lingvo are subclasses of BaseLayer found in lingvo/core/base_layer.py and inherit the following:
- A Params classmethod for configuring the layer. The common params include:
  - cls: the Python class that the Params object is associated with. This can be used to construct an instance of the class;
  - name: the name of this layer;
  - dtype: the default dtype to use when creating variables.
- An __init__ constructor. All child layers and variables should be created here.
- A CreateVariable method that is called to create variables.
- A CreateChild method that is called to create child layers.
- An FProp method that implements forward propagation through the layer.
As a reference, many examples of layers can be found in lingvo/core/layers.py, lingvo/core/attention.py, and lingvo/core/rnn_layers.py. A minimal sketch of this pattern follows.
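The sketch below is illustrative only: it defines a toy layer with no variables (real layers create variables with CreateVariable and read them from theta inside FProp), and the ScaleLayer name and its scale param are made up for this example.
In [0]:
from lingvo.core import base_layer


class ScaleLayer(base_layer.BaseLayer):
  """Toy layer that multiplies its input by a configurable constant."""

  @classmethod
  def Params(cls):
    p = super(ScaleLayer, cls).Params()
    p.Define('scale', 2.0, 'Constant multiplier applied to the input.')
    return p

  def FProp(self, theta, inputs):
    # theta would hold this layer's variables; this toy layer has none, so it
    # only reads the static param value.
    return inputs * self.params.scale


# Layers are constructed from their Params: set a name, then Instantiate.
p = ScaleLayer.Params().Set(name='scale_layer', scale=3.0)
layer = p.Instantiate()
print(layer.FProp(layer.theta, 10.0))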
The root layer for the network should be a subclass of BaseTask, found in lingvo/core/base_model.py, and should implement the following:
- A ComputePredictions method that takes the current variable values (theta) and input_batch and returns the network predictions.
- A ComputeLoss method that takes theta, input_batch, and the predictions returned from ComputePredictions, and returns a dictionary of scalar metrics, one of which should be loss. These scalar metrics are exported to TensorBoard as summaries.
- A Decode method for creating a separate graph for decoding. For example, training and evaluation might use teacher forcing while decoding might not.
- An Inference method that returns a graph with feeds and fetches that can be used together with a saved checkpoint for inference. This differs from Decode in that it can be fed data directly instead of using data from the input generator.
This codelab uses the existing network from lingvo/tasks/punctuator/model.py, which is derived from the networks in lingvo/tasks/mt/model.py, with an Inference method added for the punctuator task. The actual logic lies mostly in lingvo/tasks/mt/encoder.py and lingvo/tasks/mt/decoder.py.
After defining the input generator and the network, we need to create a model configuration with the specific hyperparameters to use for our model.
Since there is only a single task, we create a subclass of SingleTaskModelParams found in lingvo/core/base_model_params.py. It has the following structure:
- Train/Dev/Test methods that configure the input generator for the respective datasets.
- A Task method that configures the network.
The following cell contains the configuration that will be used in this codelab. It can also be found in lingvo/tasks/punctuator/params/codelab.py. The network configuration in the Task method is delegated to lingvo/tasks/mt/params/base_config.py.
Run the cell below to write the file to disk.
In [0]:
%%writefile codelab.py
import input_generator
import os
from lingvo import model_registry
import lingvo.compat as tf
from lingvo.core import base_model_params
from lingvo.tasks.mt import base_config
from lingvo.tasks.punctuator import model
# This base class defines parameters for the input generator for a specific
# dataset. Specific network architectures will be implemented in subclasses.
class BrownCorpusWPM(base_model_params.SingleTaskModelParams):
  """Brown Corpus data with a Word-Piece Model tokenizer."""

  # Generated using
  # lingvo/tasks/punctuator/tools:download_brown_corpus.
  _DATADIR = '/tmp/punctuator_data'
  _VOCAB_FILE = 'brown_corpus_wpm.16000.vocab'
  # _VOCAB_SIZE needs to be a multiple of 16 because we use a sharded softmax
  # with 16 shards.
  _VOCAB_SIZE = 16000

  def Train(self):
    p = input_generator.PunctuatorInput.Params()
    p.file_pattern = 'text:' + os.path.join(self._DATADIR, 'train.txt')
    p.file_random_seed = 0  # Do not use a fixed seed.
    p.file_parallelism = 1  # We only have a single input file.
    # The bucket upper bound specifies how to split the input into buckets. We
    # train on sequences up to the maximum bucket size and discard longer
    # examples.
    p.bucket_upper_bound = [10, 20, 30, 60, 120]
    # The bucket batch limit determines how many examples there are in each
    # batch during training. We reduce the batch size for the buckets that
    # have a higher upper bound (batches that consist of longer sequences)
    # in order to prevent out-of-memory issues.
    # Note that this hyperparameter varies widely based on the model and
    # language. Larger models may warrant smaller batches in order to fit in
    # memory, for example; and ideographic languages like Chinese may benefit
    # from more buckets.
    p.bucket_batch_limit = [512, 256, 160, 80, 40]
    p.tokenizer.vocab_filepath = self._VOCAB_FILE
    p.tokenizer.vocab_size = self._VOCAB_SIZE
    p.tokenizer.pad_to_max_length = False
    # Set the tokenizer max length slightly longer than the largest bucket to
    # discard examples that are longer than we allow.
    p.source_max_length = p.bucket_upper_bound[-1] + 2
    p.target_max_length = p.bucket_upper_bound[-1] + 2
    return p

  # There is also a Dev method for dev set params, but we don't have a dev set.

  def Test(self):
    p = input_generator.PunctuatorInput.Params()
    p.file_pattern = 'text:' + os.path.join(self._DATADIR, 'test.txt')
    p.file_random_seed = 27182818  # Fix random seed for testing.
    # The following two parameters are important if there's more than one input
    # file. For this codelab it doesn't actually matter.
    p.file_parallelism = 1  # Avoid randomness in testing.
    # In order to make exactly one pass over the dev/test sets, we set buffer
    # size to 1. Greater numbers may cause inaccurate dev/test scores.
    p.file_buffer_size = 1
    p.bucket_upper_bound = [10, 20, 30, 60, 120, 200]
    p.bucket_batch_limit = [16] * 4 + [4] * 2
    p.tokenizer.vocab_filepath = self._VOCAB_FILE
    p.tokenizer.vocab_size = self._VOCAB_SIZE
    p.tokenizer.pad_to_max_length = False
    p.source_max_length = p.bucket_upper_bound[-1] + 2
    p.target_max_length = p.bucket_upper_bound[-1] + 2
    return p


# This decorator registers the model in the Lingvo model registry. The
# registered name is derived from the module path; since we wrote this file
# locally as codelab.py, the model is registered as codelab.RNMTModel. (The
# same file in the Lingvo repository,
# lingvo/tasks/punctuator/params/codelab.py, registers as
# punctuator.codelab.RNMTModel.)
@model_registry.RegisterSingleTaskModel
class RNMTModel(BrownCorpusWPM):
  """RNMT+ Model."""

  def Task(self):
    p = base_config.SetupRNMTParams(
        model.RNMTModel.Params(),
        name='punctuator_rnmt',
        vocab_size=self._VOCAB_SIZE,
        embedding_dim=1024,
        hidden_dim=1024,
        num_heads=4,
        num_encoder_layers=6,
        num_decoder_layers=8,
        learning_rate=1e-4,
        l2_regularizer_weight=1e-5,
        lr_warmup_steps=500,
        lr_decay_start=400000,
        lr_decay_end=1200000,
        lr_min=0.5,
        ls_uncertainty=0.1,
        atten_dropout_prob=0.3,
        residual_dropout_prob=0.3,
        adam_beta2=0.98,
        adam_epsilon=1e-6,
    )
    p.eval.samples_per_summary = 2466
    return p
We are now ready to train. The following cell starts TensorBoard in the background and then launches the Lingvo trainer on the model we just registered; training runs on a local GPU and will take a while.
In [0]:
# Start tensorboard (access at http://localhost:6006)
import os
os.system('lsof -t -i:6006 || tensorboard --logdir=/tmp/punctuator &')
!python3 -m lingvo.trainer --model=codelab.RNMTModel --mode=sync --logdir=/tmp/punctuator --saver_max_to_keep=2 --noenable_asserts --run_locally=gpu
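Training writes checkpoints under the --logdir passed above; the trainer saves them in the train subdirectory, which is also where the inference cell at the end of this codelab looks for them. Once at least one checkpoint has been written, a quick check like this should print its path:
In [0]:
import lingvo.compat as tf

# Prints the newest checkpoint written by the trainer, or None if there is
# none yet.
print(tf.train.latest_checkpoint('/tmp/punctuator/train'))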
The following cell evaluates the model. In Lingvo, evaluation is meant to run alongside training as a separate process that periodically looks for the latest checkpoint and evaluates it. Since Colab gives us only a single process, running this cell will evaluate the current checkpoint once and then block indefinitely, waiting for the next checkpoint; interrupt it when you want to move on.
In [0]:
!python3 -m lingvo.trainer --model=codelab.RNMTModel --job=evaler_test --logdir=/tmp/punctuator --run_locally=cpu
There is also a decoder job that can be run in the same way. The exact difference between the evaler and the decoder varies by model; as noted above, evaluation typically uses teacher forcing while decoding does not.
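Lingvo job names follow the pattern <job>_<dataset>, so, assuming the same convention as the evaler command above, the decoder for the test set can be launched like this:
In [0]:
!python3 -m lingvo.trainer --model=codelab.RNMTModel --job=decoder_test --logdir=/tmp/punctuator --run_locally=cpu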
Finally, we can run interactive inference against the latest checkpoint by exporting the inference graph and feeding it a sentence through the predictor.
In [0]:
import string
from lingvo import compat as tf
from lingvo import model_imports
from lingvo import model_registry
from lingvo.core import inference_graph_exporter
from lingvo.core import predictor
from lingvo.core.ops.hyps_pb2 import Hypothesis
tf.flags.FLAGS.mark_as_parsed()
src = "she asked do you know the way to san jose" #@param {type:'string'}
src = src.lower().translate(str.maketrans('', '', string.punctuation))
print(src)
checkpoint = tf.train.latest_checkpoint('/tmp/punctuator/train')
print('Using checkpoint %s' % checkpoint)
# Run inference
params = model_registry.GetParams('codelab.RNMTModel', 'Test')
inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(params)
pred = predictor.Predictor(inference_graph,
                           checkpoint=checkpoint,
                           device_type='cpu')
src_ids, decoded, scores, hyps = pred.Run(
    ['src_ids', 'topk_decoded', 'topk_scores', 'topk_hyps'], src_strings=[src])
# print(src_ids[0])
for text, score in zip(decoded[0].tolist(), scores[0].tolist()):
  print("%.5f: %s" % (score, text))
# for i, hyp in enumerate(hyps[0]):
#   print("=======hyp %d=======" % i)
#   print(Hypothesis().FromString(hyp))
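The cell above prints every returned hypothesis with its score. If you only want the single best restoration, you can pick the hypothesis with the highest score; the sketch below assumes, as is usual for beam-search outputs, that higher (less negative) log-probability scores are better:
In [0]:
import numpy as np

best_idx = int(np.argmax(scores[0]))
best = decoded[0].tolist()[best_idx]
# topk_decoded entries are bytes objects, so decode before printing.
print(best.decode('utf-8') if isinstance(best, bytes) else best)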
Footnote: One might wonder why our result places the question mark outside the quotation marks. This happens because the Brown Corpus follows a 1959 US Patent Office procedure for transliterating texts onto punch cards, in which the closing quotation mark is always punched before the punctuation mark, so that the punctuation mark occurs at the end of the sentence. See this link for more details. Our model is simply following this pattern in its training data.
For more advanced topics or to get a deeper understanding of Lingvo beyond this codelab, see the paper.