Neural Machine Translation

Author: https://github.com/A-Jacobson/minimal-nmt

Note that the attention implementation provided in this tutorial will not work as the solution to this homework. However, it should be useful for understanding how attention can be implemented as part of a sequence-to-sequence model.


In [72]:
import math
import torch
import random
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
import torch.nn.functional as F
import torchtext
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
from torch.nn.utils import clip_grad_norm
import spacy

Convenience Functions


In [73]:
def sequence_to_text(sequence, field):
    # look up each index in the field's vocabulary
    return " ".join([field.vocab.itos[int(i)] for i in sequence])

Load Multi30k English/German parallel corpus for NMT

TorchText takes care of tokenization, padding, special tokens, and batching.


In [74]:
from spacy.lang.de import German
from spacy.lang.en import English

def load_dataset(batch_size, device='cpu'):
    spacy_de = German()
    spacy_en = English()

    def tokenize_de(text):
        return [tok.text for tok in spacy_de.tokenizer(text)]

    def tokenize_en(text):
        return [tok.text for tok in spacy_en.tokenizer(text)]

    DE = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>')
    EN = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>')

    train, val, test = Multi30k.splits(exts=('.de', '.en'), fields=(DE, EN))

    DE.build_vocab(train.src)
    EN.build_vocab(train.trg)

    train_iter, val_iter, test_iter = BucketIterator.splits(
        (train, val, test), batch_size=batch_size, device=device, repeat=False)
    return train_iter, val_iter, test_iter, DE, EN

Model Inputs

Model inputs are (seq_len, batch_size) Tensors of word indices


In [75]:
train_iter, val_iter, test_iter, DE, EN = load_dataset(batch_size=5)
example_batch = next(iter(train_iter))
example_batch.src, example_batch.trg


Out[75]:
(tensor([[    2,     2,     2,     2,     2],
         [    5,     5,     5,    92,    14],
         [   12,    12,    12,     7,    17],
         [    8,     7,     7,  1422,    10],
         [   16,     6,   238,   103,   117],
         [   18,    72,  9090,     6,    93],
         [18019,    41,  2090,    97,    32],
         [  388,     9,     7,     4,    11],
         [    9,   158,     6,     3,    13],
         [   20,   183,   734,     1,  5356],
         [  615,   221,    10,     1,  1173],
         [   62,    20,   309,     1, 14399],
         [    8,  1717,  5220,     1,     4],
         [   39,    28,  1316,     1,     3],
         [   18,  6618,     4,     1,     1],
         [  479,    11,     3,     1,     1],
         [    7,     4,     1,     1,     1],
         [   19,     3,     1,     1,     1],
         [   36,     1,     1,     1,     1],
         [    4,     1,     1,     1,     1],
         [    3,     1,     1,     1,     1]]),
 tensor([[   2,    2,    2,    2,    2],
         [   6,    6,    6, 1779,    6],
         [  12,   12,   12,    7,   16],
         [  21,    7,    7,  234,    7],
         [   4,    4,    4,   98,   26],
         [1461,   26,   31,    4,   85],
         [ 346,   23,  969,   69,   10],
         [  11,   11,   43,    5,   32],
         [   4,   24,   41,    3,    9],
         [ 335,  153,   23,    1,    4],
         [  68, 5092,   10,    1, 1397],
         [ 144,    7, 1767,    1,  253],
         [   4,   44,    4,    1,   89],
         [ 441,   13,  634,    1,  148],
         [   7,  618,    7,    1,    5],
         [   8,    5,    4,    1,    3],
         [ 147,    3,  654,    1,    1],
         [   3,    1,   14,    1,    1],
         [   1,    1,   30,    1,    1],
         [   1,    1,  533,    1,    1],
         [   1,    1, 1122,    1,    1],
         [   1,    1,    5,    1,    1],
         [   1,    1,    3,    1,    1]]))

We can recover the original text by looking up each index in the vocabularies we built with the load_dataset function.


In [76]:
print(sequence_to_text(example_batch.src[:, 0], DE))
print(sequence_to_text(example_batch.trg[:, 0], EN))


<sos> Ein Mann , der eine reflektierende Weste und einen Schutzhelm trägt , hält eine Flagge in die Straße . <eos>
<sos> A man wearing a reflective vest and a hard hat holds a flag in the road <eos> <pad> <pad> <pad> <pad> <pad>

Architecture

NMT uses an encoder-decoder architecture to translate between source and target sequences of different lengths.

Encoder

Encodes each word of the source sequence into a hidden_dim-dimensional feature vector, sometimes called an annotation. Also returns the final hidden state of the encoder bi-directional GRU.


In [77]:
class Encoder(nn.Module):
    def __init__(self, source_vocab_size, embed_dim, hidden_dim,
                 n_layers, dropout):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.embed = nn.Embedding(source_vocab_size, embed_dim, padding_idx=1)
        self.gru = nn.GRU(embed_dim, hidden_dim, n_layers,
                          dropout=dropout, bidirectional=True)

    def forward(self, source, hidden=None):
        embedded = self.embed(source)  # (seq_len, batch, embed_dim)
        encoder_out, encoder_hidden = self.gru(
            embedded, hidden)  # (seq_len, batch, hidden_dim*2)
        # sum bidirectional outputs, the other option is to retain concat features
        encoder_out = (encoder_out[:, :, :self.hidden_dim] +
                       encoder_out[:, :, self.hidden_dim:])
        return encoder_out, encoder_hidden
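
The comment inside forward mentions that the alternative to summing the two directions is to keep the concatenated features. Below is a sketch of that variant; the ConcatEncoder name and the projection layer are my own additions, included so the attention shapes used later still match hidden_dim.

class ConcatEncoder(nn.Module):
    """
    Hypothetical variant of Encoder that keeps the concatenated bidirectional
    features and projects them back down to hidden_dim.
    """
    def __init__(self, source_vocab_size, embed_dim, hidden_dim, n_layers, dropout):
        super(ConcatEncoder, self).__init__()
        self.embed = nn.Embedding(source_vocab_size, embed_dim, padding_idx=1)
        self.gru = nn.GRU(embed_dim, hidden_dim, n_layers,
                          dropout=dropout, bidirectional=True)
        # assumption: a linear projection keeps downstream shapes unchanged
        self.project = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, source, hidden=None):
        embedded = self.embed(source)                             # (seq_len, batch, embed_dim)
        encoder_out, encoder_hidden = self.gru(embedded, hidden)  # (seq_len, batch, hidden_dim*2)
        encoder_out = self.project(encoder_out)                   # (seq_len, batch, hidden_dim)
        return encoder_out, encoder_hidden

Because of the projection, the decoder and attention modules below could be used with this variant unchanged.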

In [78]:
embed_dim = 256
hidden_dim = 512
n_layers = 2
dropout = 0.5

In [79]:
encoder = Encoder(source_vocab_size=len(DE.vocab), embed_dim=embed_dim,
                  hidden_dim=hidden_dim, n_layers=n_layers, dropout=dropout)

In [80]:
encoder_out, encoder_hidden = encoder(example_batch.src)
print('encoder output size: ', encoder_out.size())  # (source_len, batch, hidden_dim)
print('encoder hidden size: ', encoder_hidden.size())  # (n_layers * num_directions, batch, hidden_dim)


encoder output size:  torch.Size([21, 5, 512])
encoder hidden size:  torch.Size([4, 5, 512])

Attention

Currently encoder_out is a length-21 sequence and the target is a length-23 sequence. We need to compress the information in encoder_out into a context vector, which should hold all the information the decoder needs to predict the next step of its output. We will use Luong attention to create this context vector.


In [81]:
class LuongAttention(nn.Module):
    """
    LuongAttention from Effective Approaches to Attention-based Neural Machine Translation
    https://arxiv.org/pdf/1508.04025.pdf
    """

    def __init__(self, dim):
        super(LuongAttention, self).__init__()
        self.W = nn.Linear(dim, dim, bias=False)

    def score(self, decoder_hidden, encoder_out):
        # linearly transform encoder_out: (seq, batch, dim)
        encoder_out = self.W(encoder_out)
        # (batch, seq, dim)
        encoder_out = encoder_out.permute(1, 0, 2)
        # (batch, seq, dim) @ (batch, dim, 1) -> (batch, seq, 1)
        return encoder_out @ decoder_hidden.permute(1, 2, 0)

    def forward(self, decoder_hidden, encoder_out):
        energies = self.score(decoder_hidden, encoder_out)
        mask = F.softmax(energies, dim=1)  # (batch, seq, 1), weights over source positions
        # (batch, dim, seq) @ (batch, seq, 1) -> (batch, dim, 1)
        context = encoder_out.permute(1, 2, 0) @ mask
        context = context.permute(2, 0, 1)  # (1, batch, dim)
        mask = mask.permute(2, 0, 1)  # (1, batch, source_len)
        return context, mask

This will normally be part of the decoder, since it takes the previous decoder hidden state as input, but to show its inputs and outputs I will call it directly here.

We will initialize the decoder RNN's hidden state with the last hidden state from the encoder. Because the encoder is bi-directional, we have to slice its hidden state to select the layers we want.


In [82]:
attention = LuongAttention(dim=hidden_dim)
context, mask = attention(encoder_hidden[-1:], encoder_out)
print(context.size())  # (1, batch, attention_dim) context vector
print(mask.size())  # the weights used to compute weighted sum over encoder out (1, batch, source_len)


torch.Size([1, 5, 512])
torch.Size([1, 5, 21])
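
Since the mask holds the softmax weights over the 21 source positions, each batch element's weights should sum to 1. A quick sanity check (my own addition, not in the original notebook):

print(mask.sum(dim=2))  # expect values close to 1.0 for each of the 5 batch elements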

Decoder with attention


In [83]:
class Decoder(nn.Module):
    def __init__(self, target_vocab_size, embed_dim, hidden_dim,
                 n_layers, dropout):
        super(Decoder, self).__init__()
        self.n_layers = n_layers
        self.embed = nn.Embedding(target_vocab_size, embed_dim, padding_idx=1)
        self.attention = LuongAttention(hidden_dim)
        self.gru = nn.GRU(embed_dim + hidden_dim, hidden_dim, n_layers,
                          dropout=dropout)
        self.out = nn.Linear(hidden_dim * 2, target_vocab_size)

    def forward(self, output, encoder_out, decoder_hidden):
        """
        decodes one output frame
        """
        embedded = self.embed(output)  # (1, batch, embed_dim)
        context, mask = self.attention(decoder_hidden[-1:], encoder_out)  # attend with the top layer's hidden state: (1, batch, hidden_dim)
        rnn_output, decoder_hidden = self.gru(torch.cat([embedded, context], dim=2),
                                              decoder_hidden)
        output = self.out(torch.cat([rnn_output, context], 2))
        return output, decoder_hidden, mask

In [84]:
decoder = Decoder(target_vocab_size=len(EN.vocab), embed_dim=embed_dim,
                  hidden_dim=hidden_dim, n_layers=n_layers, dropout=dropout)

To translate one word from German to English, the decoder needs:

  1. encoder_outputs
  2. decoder_hidden: initially the last n_layers of encoder_hidden, then its own returned hidden state.
  3. previous_output: a batch of start-of-string tokens (index 2) at the first step, then its previous prediction (or the previous target, with teacher forcing).

The attention mask that the decoder returns is not used in training but can be used to visualize where the decoder is "looking" in the input sequence in order to generate its current output.


In [85]:
decoder_hidden = encoder_hidden[-decoder.n_layers:]
start_token = example_batch.trg[:1]
start_token


Out[85]:
tensor([[2, 2, 2, 2, 2]])
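
The slicing encoder_hidden[-decoder.n_layers:] above works because PyTorch stores the bidirectional hidden state as (n_layers * num_directions, batch, hidden_dim), ordered layer by layer, so the last entries belong to the top layer. A more explicit, equivalent view (my own sketch, not part of the original notebook):

# (n_layers * num_directions, batch, hidden_dim) -> (n_layers, num_directions, batch, hidden_dim)
h = encoder_hidden.view(n_layers, 2, -1, hidden_dim)
top_layer = h[-1]  # forward and backward hidden states of the top encoder layer
assert torch.equal(top_layer, encoder_hidden[-decoder.n_layers:])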

In [86]:
output, decoder_hidden, mask = decoder(start_token, encoder_out, decoder_hidden)

In [87]:
print('output size: ', output.size())  # (1, batch, target_vocab) # predicted probability distribution over all possible target words
print('decoder hidden size ', decoder_hidden.size())
print('attention mask size', mask.size())


output size:  torch.Size([1, 5, 10839])
decoder hidden size  torch.Size([2, 5, 512])
attention mask size torch.Size([1, 5, 21])

Decoding Helpers

NMT models use teacher forcing during training and greedy decoding or beam search for inference. To accommodate these behaviors, I've made simple helper classes that get output from the decoder under each policy.

The Teacher class sometimes feeds the previous target to the decoder rather than the model's previous prediction. This can help speed up convergence, but requires the targets to be loaded into the helper before each batch.


In [88]:
class Teacher:
    def __init__(self, teacher_forcing_ratio=0.5):
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.targets = None
        self.maxlen = 0
        
    def load_targets(self, targets):
        self.targets = targets
        self.maxlen = len(targets)

    def generate(self, decoder, encoder_out, encoder_hidden):
        outputs = []
        masks = []
        decoder_hidden = encoder_hidden[-decoder.n_layers:]  # take what we need from encoder
        output = self.targets[0].unsqueeze(0)  # start token
        for t in range(1, self.maxlen):
            output, decoder_hidden, mask = decoder(output, encoder_out, decoder_hidden)
            outputs.append(output)
            masks.append(mask.data)
            output = Variable(output.data.max(dim=2)[1])
            # teacher forcing
            is_teacher = random.random() < self.teacher_forcing_ratio
            if is_teacher:
                output = self.targets[t].unsqueeze(0)      
        return torch.cat(outputs), torch.cat(masks).permute(1, 2, 0)  # batch, src, trg

In [89]:
decode_helper = Teacher()
decode_helper.load_targets(example_batch.trg)
outputs, masks = decode_helper.generate(decoder, encoder_out, encoder_hidden)

Calc loss

Reshape outputs and targets, ignoring the <sos> token at the start of the target batch.


In [90]:
F.cross_entropy(outputs.view(-1, outputs.size(2)),
                           example_batch.trg[1:].view(-1), ignore_index=1)


Out[90]:
tensor(9.2903, grad_fn=<NllLossBackward>)

The greedy decoder simply chooses the highest-scoring word as output. We can use the set_maxlen method to generate sequences the same length as our targets, which makes it easy to check perplexity and BLEU score during evaluation.


In [91]:
class Greedy:
    def __init__(self, maxlen=20, sos_index=2):
        self.maxlen = maxlen
        self.sos_index = sos_index
        
    def set_maxlen(self, maxlen):
        self.maxlen = maxlen
        
    def generate(self, decoder, encoder_out, encoder_hidden):
        seq, batch, _ = encoder_out.size()
        outputs = []
        masks = []
        decoder_hidden = encoder_hidden[-decoder.n_layers:]  # take what we need from encoder
        output = Variable(torch.zeros(1, batch).long() + self.sos_index)  # start token
        for t in range(self.maxlen):
            output, decoder_hidden, mask = decoder(output, encoder_out, decoder_hidden)
            outputs.append(output)
            masks.append(mask.data)
            output = Variable(output.data.max(dim=2)[1])
        return torch.cat(outputs), torch.cat(masks).permute(1, 2, 0)  # batch, src, trg

In [92]:
decode_helper = Greedy()
decode_helper.set_maxlen(len(example_batch.trg[1:]))
outputs, masks = decode_helper.generate(decoder, encoder_out, encoder_hidden)

In [93]:
outputs.size()


Out[93]:
torch.Size([22, 5, 10839])

In [94]:
F.cross_entropy(outputs.view(-1, outputs.size(2)),
                           example_batch.trg[1:].view(-1), ignore_index=1)


Out[94]:
tensor(9.3025, grad_fn=<NllLossBackward>)
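
The note above mentions checking perplexity during evaluation. Since F.cross_entropy returns the mean negative log-likelihood in nats, perplexity is just its exponential; a small sketch of my own:

loss = F.cross_entropy(outputs.view(-1, outputs.size(2)),
                       example_batch.trg[1:].view(-1), ignore_index=1)
print('perplexity:', torch.exp(loss).item())  # exp of the mean cross-entropy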

Seq2Seq wrapper


In [95]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, decoding_helper):
        encoder_out, encoder_hidden = self.encoder(source)
        outputs, masks = decoding_helper.generate(self.decoder, encoder_out, encoder_hidden)
        return outputs, masks

In [96]:
seq2seq = Seq2Seq(encoder, decoder)
decoding_helper = Teacher(teacher_forcing_ratio=0.5)

Example iteration with the wrapper


In [97]:
decoding_helper.load_targets(example_batch.trg)
outputs, masks = seq2seq(example_batch.src, decoding_helper)

In [98]:
outputs.size(), masks.size()


Out[98]:
(torch.Size([22, 5, 10839]), torch.Size([5, 21, 22]))
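
As mentioned earlier, the attention masks can be used to visualize where the decoder is "looking" in the source while generating each target word. Here masks is (batch, source_len, target_len), so one sentence's weights plot naturally as a heatmap. A minimal sketch; matplotlib is not imported above, so this is an assumed addition:

import matplotlib.pyplot as plt

attn = masks[0].cpu().numpy()  # (source_len, target_len) weights for the first sentence
plt.imshow(attn, cmap='viridis')
plt.xlabel('target step')
plt.ylabel('source position')
plt.colorbar()
plt.show()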

In [99]:
F.cross_entropy(outputs.view(-1, outputs.size(2)),
                           example_batch.trg[1:].view(-1), ignore_index=1)


Out[99]:
tensor(9.2929, grad_fn=<NllLossBackward>)
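
Finally, a minimal training-loop sketch that ties the wrapper, the Teacher helper, and the loss together. This is not part of the original notebook; the learning rate, clip value, and epoch count are placeholder assumptions.

optimizer = Adam(seq2seq.parameters(), lr=3e-4)  # learning rate is an assumption
decoding_helper = Teacher(teacher_forcing_ratio=0.5)

for epoch in range(1):  # placeholder: a single epoch
    for batch in train_iter:
        decoding_helper.load_targets(batch.trg)
        outputs, _ = seq2seq(batch.src, decoding_helper)
        loss = F.cross_entropy(outputs.view(-1, outputs.size(2)),
                               batch.trg[1:].view(-1), ignore_index=1)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(seq2seq.parameters(), 10.0)  # clip value is an assumption
        optimizer.step()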
