In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Neural Machine Translation with Luong Attention - Tensorflow 2.0

Notebook originally contributed by: chunml

Overview

In this notebook, we will create a neural machine translation model using Sequence-to-Sequence (Seq2Seq) learning and the Luong attention mechanism. We will walk through the following steps:

  • Create a vanilla Seq2Seq model that can overfit a tiny dataset of 20 English-French sentence pairs
  • Train the vanilla Seq2Seq model on the full dataset
  • Introduce the Luong attention mechanism, compare results and visualize attention heatmaps

For a more detailed explanation of the Luong attention mechanism and its implementation, refer to my blog post: Neural Machine Translation With Attention Mechanism

Setup

First, we need to make sure TensorFlow 2.x is installed. The cell below imports it and checks the major version.


In [0]:
import tensorflow as tf
assert tf.__version__.startswith('2')

Import and prepare data

Next, let's import the necessary packages and define the tiny dataset for overfitting. The tiny dataset contains only 20 pairs of English - French sentences, which were extracted from the original dataset we will use later on.


In [0]:
import numpy as np
import unicodedata
import re

raw_data = (
    ('What a ridiculous concept!', 'Quel concept ridicule !'),
    ('Your idea is not entirely crazy.', "Votre idée n'est pas complètement folle."),
    ("A man's worth lies in what he is.", "La valeur d'un homme réside dans ce qu'il est."),
    ('What he did is very wrong.', "Ce qu'il a fait est très mal."),
    ("All three of you need to do that.", "Vous avez besoin de faire cela, tous les trois."),
    ("Are you giving me another chance?", "Me donnez-vous une autre chance ?"),
    ("Both Tom and Mary work as models.", "Tom et Mary travaillent tous les deux comme mannequins."),
    ("Can I have a few minutes, please?", "Puis-je avoir quelques minutes, je vous prie ?"),
    ("Could you close the door, please?", "Pourriez-vous fermer la porte, s'il vous plaît ?"),
    ("Did you plant pumpkins this year?", "Cette année, avez-vous planté des citrouilles ?"),
    ("Do you ever study in the library?", "Est-ce que vous étudiez à la bibliothèque des fois ?"),
    ("Don't be deceived by appearances.", "Ne vous laissez pas abuser par les apparences."),
    ("Excuse me. Can you speak English?", "Je vous prie de m'excuser ! Savez-vous parler anglais ?"),
    ("Few people know the true meaning.", "Peu de gens savent ce que cela veut réellement dire."),
    ("Germany produced many scientists.", "L'Allemagne a produit beaucoup de scientifiques."),
    ("Guess whose birthday it is today.", "Devine de qui c'est l'anniversaire, aujourd'hui !"),
    ("He acted like he owned the place.", "Il s'est comporté comme s'il possédait l'endroit."),
    ("Honesty will pay in the long run.", "L'honnêteté paye à la longue."),
    ("How do we know this isn't a trap?", "Comment savez-vous qu'il ne s'agit pas d'un piège ?"),
    ("I can't believe you're giving up.", "Je n'arrive pas à croire que vous abandonniez."),
)

Preprocess the data

The raw data can't be used yet. We need to apply some preprocessing steps such as:

  • Convert strings from unicode to ascii
  • Add space before punctuation
  • Filter out unwanted tokens
  • Add start and end tokens to target sentences

In [0]:
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn')


def normalize_string(s):
    s = unicode_to_ascii(s)
    s = re.sub(r'([!.?])', r' \1', s)
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    s = re.sub(r'\s+', r' ', s)
    return s


raw_data_en, raw_data_fr = list(zip(*raw_data))
raw_data_en, raw_data_fr = list(raw_data_en), list(raw_data_fr)
raw_data_en = [normalize_string(data) for data in raw_data_en]
raw_data_fr_in = ['<start> ' + normalize_string(data) for data in raw_data_fr]
raw_data_fr_out = [normalize_string(data) + ' <end>' for data in raw_data_fr]
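
To make these steps concrete, here is a small self-contained sketch that repeats the two helper functions from the cell above and shows what one sentence pair from the tiny dataset looks like after preprocessing. The variable names en, fr, fr_in and fr_out are only for illustration.

```python
import re
import unicodedata


def unicode_to_ascii(s):
    # Decompose accented characters and drop the combining marks
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn')


def normalize_string(s):
    s = unicode_to_ascii(s)
    s = re.sub(r'([!.?])', r' \1', s)      # add a space before punctuation
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)  # replace every other character run
    s = re.sub(r'\s+', r' ', s)            # collapse repeated whitespace
    return s


en = normalize_string('Your idea is not entirely crazy.')
fr = normalize_string("Votre idée n'est pas complètement folle.")
fr_in = '<start> ' + fr   # what the decoder reads
fr_out = fr + ' <end>'    # what the decoder is trained to predict

print(en)
print(fr_in)
print(fr_out)
```

Note that accents and apostrophes are lost along the way. The decoder input and the decoder target are the same sentence shifted by one position: at each step the decoder reads a token from fr_in and is trained to predict the token at the same position in fr_out.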

Tokenize the raw data

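Since neural networks only accept numeric arrays as input, we need to tokenize our data, i.e. convert the sentences to integer sequences. This task can be done easily with the Keras Tokenizer utility class. We then pad the sequences with zeros at the end (padding='post') so that they all have the same length.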


In [0]:
en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
en_tokenizer.fit_on_texts(raw_data_en)
data_en = en_tokenizer.texts_to_sequences(raw_data_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en,
                                                        padding='post')

fr_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
fr_tokenizer.fit_on_texts(raw_data_fr_in)
fr_tokenizer.fit_on_texts(raw_data_fr_out)
data_fr_in = fr_tokenizer.texts_to_sequences(raw_data_fr_in)
data_fr_in = tf.keras.preprocessing.sequence.pad_sequences(data_fr_in,
                                                           padding='post')

data_fr_out = fr_tokenizer.texts_to_sequences(raw_data_fr_out)
data_fr_out = tf.keras.preprocessing.sequence.pad_sequences(data_fr_out,
                                                            padding='post')
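
If you are curious what tokenization and post-padding amount to, the following pure-Python sketch may help. It is a simplified stand-in, not the Keras implementation: fit_vocab and texts_to_padded are hypothetical helpers written for this illustration, and they only mimic the idea of frequency-ordered indices starting at 1, with 0 reserved for padding.

```python
from collections import Counter


def fit_vocab(texts):
    # The most frequent word gets index 1; index 0 is reserved for padding
    counts = Counter(word for text in texts for word in text.lower().split())
    return {word: i + 1 for i, (word, _) in enumerate(counts.most_common())}


def texts_to_padded(texts, word_index):
    seqs = [[word_index[w] for w in text.lower().split()] for text in texts]
    max_len = max(len(s) for s in seqs)
    # 'post' padding: zeros are appended at the end of shorter sequences
    return [s + [0] * (max_len - len(s)) for s in seqs]


texts = ['I like tea', 'I like green tea .']
word_index = fit_vocab(texts)
padded = texts_to_padded(texts, word_index)

print(word_index)
print(padded)
```

Because index 0 never maps to a real word, the vocabulary size used later in this notebook is len(word_index) + 1.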

Create tf.data.Dataset

Next, we need to create an instance of tf.data.Dataset, which will help us create batches and iterate over the entire dataset.


In [0]:
BATCH_SIZE = 5
dataset = tf.data.Dataset.from_tensor_slices(
    (data_en, data_fr_in, data_fr_out))
dataset = dataset.shuffle(20).batch(BATCH_SIZE)

Create the Encoder

Now that we are done with the data preparation step, let's move on to creating the model. The vanilla Seq2Seq model consists of an Encoder and a Decoder.

The Encoder only has an embedding layer and an RNN layer (can be either vanilla RNN, LSTM or GRU). We also need a method to initialize the state (zero-state).


In [0]:
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, lstm_size):
        super(Encoder, self).__init__()
        self.lstm_size = lstm_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            lstm_size, return_sequences=True, return_state=True)

    def call(self, sequence, states):
        embed = self.embedding(sequence)
        output, state_h, state_c = self.lstm(embed, initial_state=states)

        return output, state_h, state_c

    def init_states(self, batch_size):
        return (tf.zeros([batch_size, self.lstm_size]),
                tf.zeros([batch_size, self.lstm_size]))

Create the Decoder

Next, we will create the Decoder. Basically, the Decoder is similar to the Encoder, with an additional Dense layer to map the RNN output to vocabulary space.


In [0]:
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, lstm_size):
        super(Decoder, self).__init__()
        self.lstm_size = lstm_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            lstm_size, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, sequence, state):
        embed = self.embedding(sequence)
        lstm_out, state_h, state_c = self.lstm(embed, state)
        logits = self.dense(lstm_out)

        return logits, state_h, state_c

Build the model

We will not use Keras' compile and fit methods this time, so we have to build the model manually. The easiest way to do that is to feed in some dummy data, which makes the layers create their weights. We can also print out the output shapes for debugging purposes.


In [0]:
en_vocab_size = len(en_tokenizer.word_index) + 1
fr_vocab_size = len(fr_tokenizer.word_index) + 1

EMBEDDING_SIZE = 32
LSTM_SIZE = 64

encoder = Encoder(en_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)
decoder = Decoder(fr_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)

initial_states = encoder.init_states(1)
encoder_outputs = encoder(tf.constant([[1, 2, 3]]), initial_states)
decoder_outputs = decoder(tf.constant([[1, 2, 3]]), encoder_outputs[1:])

Create the loss function and the optimizer

Next, let's create the loss function. Since we apply zero padding to the source and target sequences, we need to mask those zeros out when computing the loss.

As for the optimizer, we will use Adam with its default settings.


In [0]:
def loss_func(targets, logits):
    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    mask = tf.math.logical_not(tf.math.equal(targets, 0))
    mask = tf.cast(mask, dtype=tf.int64)
    loss = crossentropy(targets, logits, sample_weight=mask)

    return loss

optimizer = tf.keras.optimizers.Adam()
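
To see what the mask does, here is a minimal NumPy sketch of a masked cross-entropy. It is not the Keras loss itself: masked_crossentropy is a hypothetical helper, and averaging over the real tokens only is an assumption made for this illustration (the Keras loss above applies its own reduction). The point is simply that padded positions contribute nothing to the loss.

```python
import numpy as np


def masked_crossentropy(targets, logits):
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Negative log-probability of the correct token at every position
    token_loss = -np.take_along_axis(
        log_probs, targets[..., None], axis=-1)[..., 0]

    # 1.0 for real tokens, 0.0 for padding (index 0)
    mask = (targets != 0).astype(token_loss.dtype)
    masked = token_loss * mask

    # Assumption for this sketch: average over the real tokens only
    return masked.sum() / mask.sum(), masked


targets = np.array([[2, 1, 0, 0]])  # two real tokens, two padding zeros
logits = np.zeros((1, 4, 3))        # uniform scores over a 3-word vocabulary
loss, per_token = masked_crossentropy(targets, logits)

print(loss)
print(per_token)
```

With uniform scores every real token costs log(3), while the two padded positions cost exactly zero.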

The train_step function

The train_step function runs one full training iteration, i.e. a forward pass followed by a backward pass. Since TensorFlow 2.0, we can use the @tf.function decorator to compile a particular piece of code into a static graph. If you want to debug the function in eager mode, don't forget to remove the decorator.


In [0]:
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out, en_initial_states):
    with tf.GradientTape() as tape:
        en_outputs = encoder(source_seq, en_initial_states)
        en_states = en_outputs[1:]
        de_states = en_states

        de_outputs = decoder(target_seq_in, de_states)
        logits = de_outputs[0]
        loss = loss_func(target_seq_out, logits)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss

The predict function

It's always a good idea to see how well the model can do during training, since only monitoring the loss values doesn't tell us that much. Basically, the predict function only does a forward pass. However, on the Decoder's side, we will start with the start token. At every next time step, the output of the previous step will be used as the input of the current step.


In [0]:
def predict(test_source_text=None):
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)

    en_initial_states = encoder.init_states(1)
    en_outputs = encoder(tf.constant(test_source_seq), en_initial_states)

    de_input = tf.constant([[fr_tokenizer.word_index['<start>']]])
    de_state_h, de_state_c = en_outputs[1:]
    out_words = []

    while True:
        de_output, de_state_h, de_state_c = decoder(
            de_input, (de_state_h, de_state_c))
        de_input = tf.argmax(de_output, -1)
        out_words.append(fr_tokenizer.index_word[de_input.numpy()[0][0]])

        if out_words[-1] == '<end>' or len(out_words) >= 20:
            break

    print(' '.join(out_words))
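
The decoding loop above is usually called greedy decoding. Stripped of the TensorFlow details, the control flow can be sketched as follows. Everything here is hypothetical and only for illustration: greedy_decode takes any step function that returns scores and a new state, and toy_step is a fake "decoder" driven by a lookup table.

```python
def greedy_decode(step_fn, start_token, end_token, init_state, max_len=20):
    token, state, output = start_token, init_state, []
    while len(output) < max_len:
        scores, state = step_fn(token, state)
        # Greedy choice: take the index with the highest score
        token = max(range(len(scores)), key=scores.__getitem__)
        output.append(token)
        # Stop as soon as the end token is produced
        if token == end_token:
            break
    return output


# Toy vocabulary: 0=<pad>, 1=<start>, 2=<end>, 3 and 4 are ordinary words
NEXT_TOKEN = {1: 3, 3: 4, 4: 2}


def toy_step(token, state):
    scores = [0.0] * 5
    scores[NEXT_TOKEN[token]] = 1.0
    return scores, state + 1  # the state here just counts the steps


decoded = greedy_decode(toy_step, start_token=1, end_token=2, init_state=0)
print(decoded)
```

The max_len guard plays the same role as the len(out_words) >= 20 check in predict: it keeps the loop from running forever when the model never emits the end token.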

The training loop

Here comes the training loop. We will train for 300 epochs and print out the loss value together with the translation result of a random English sentence (within the training data). Doing so will help us know if something went wrong along the way.

At first, the model produces only nonsense translations, but we can see that it keeps getting better over time.


In [0]:
NUM_EPOCHS = 300
for e in range(NUM_EPOCHS):
    en_initial_states = encoder.init_states(BATCH_SIZE)
    
    predict()

    for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
        loss = train_step(source_seq, target_seq_in,
                          target_seq_out, en_initial_states)

    print('Epoch {} Loss {:.4f}'.format(e + 1, loss.numpy()))

Let's do a full inference

Finally, let's have the model translate all 20 training pairs. We can see that at this point, the model has completely learned by heart the training data, which is what we wanted at the beginning.


In [0]:
test_sents = (
    'What a ridiculous concept!',
    'Your idea is not entirely crazy.',
    "A man's worth lies in what he is.",
    'What he did is very wrong.',
    "All three of you need to do that.",
    "Are you giving me another chance?",
    "Both Tom and Mary work as models.",
    "Can I have a few minutes, please?",
    "Could you close the door, please?",
    "Did you plant pumpkins this year?",
    "Do you ever study in the library?",
    "Don't be deceived by appearances.",
    "Excuse me. Can you speak English?",
    "Few people know the true meaning.",
    "Germany produced many scientists.",
    "Guess whose birthday it is today.",
    "He acted like he owned the place.",
    "Honesty will pay in the long run.",
    "How do we know this isn't a trap?",
    "I can't believe you're giving up.",
)

for test_sent in test_sents:
    test_sequence = normalize_string(test_sent)
    predict(test_sequence)

Let's train Seq2Seq on the full dataset

Now, we have created a fully functional Seq2Seq model. Let's use it to train on the full dataset, which contains approximately 160K English - French pairs. That's huge!

But don't worry since we don't have to make any change to the current workflow. We only need to download the full dataset, tweak the networks' hyperparameters, reinitialize everything and we are ready to train.

The training is gonna take a (long) while, so be patient!


In [0]:
import requests
import os
from zipfile import ZipFile

MODE = 'train'
URL = 'http://www.manythings.org/anki/fra-eng.zip'
FILENAME = 'fra-eng.zip'
BATCH_SIZE = 64
EMBEDDING_SIZE = 256
LSTM_SIZE = 512
NUM_EPOCHS = 15


# ================= DOWNLOAD AND READ THE DATA ======================
def maybe_download_and_read_file(url, filename):
    if not os.path.exists(filename):
        session = requests.Session()
        response = session.get(url, stream=True)

        CHUNK_SIZE = 32768
        with open(filename, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:
                    f.write(chunk)

    # Read the sentence pairs (fra.txt) straight from the archive
    with ZipFile(filename) as zipf:
        with zipf.open('fra.txt') as f:
            lines = f.read()

    return lines


lines = maybe_download_and_read_file(URL, FILENAME)
lines = lines.decode('utf-8')

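raw_data = []
for line in lines.split('\n'):
    # Keep only the English and French fields; some versions of the file
    # may carry extra tab-separated columns (e.g. attribution)
    raw_data.append(line.split('\t')[:2])

# Drop the empty entry produced by the trailing newline
raw_data = raw_data[:-1]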


# ================= TOKENIZATION AND ZERO PADDING ===================
raw_data_en, raw_data_fr = zip(*raw_data)
raw_data_en = [normalize_string(data) for data in raw_data_en]
raw_data_fr_in = ['<start> ' + normalize_string(data) for data in raw_data_fr]
raw_data_fr_out = [normalize_string(data) + ' <end>' for data in raw_data_fr]

en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
en_tokenizer.fit_on_texts(raw_data_en)
data_en = en_tokenizer.texts_to_sequences(raw_data_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en,
                                                        padding='post')

fr_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
fr_tokenizer.fit_on_texts(raw_data_fr_in)
fr_tokenizer.fit_on_texts(raw_data_fr_out)
data_fr_in = fr_tokenizer.texts_to_sequences(raw_data_fr_in)
data_fr_in = tf.keras.preprocessing.sequence.pad_sequences(data_fr_in,
                                                           padding='post')

data_fr_out = fr_tokenizer.texts_to_sequences(raw_data_fr_out)
data_fr_out = tf.keras.preprocessing.sequence.pad_sequences(data_fr_out,
                                                            padding='post')

dataset = tf.data.Dataset.from_tensor_slices(
    (data_en, data_fr_in, data_fr_out))
dataset = dataset.shuffle(len(raw_data_en)).batch(
    BATCH_SIZE, drop_remainder=True)

# ======================== BUILD THE MODEL ==========================
en_vocab_size = len(en_tokenizer.word_index) + 1
encoder = Encoder(en_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)
initial_state = encoder.init_states(1)
test_encoder_output = encoder(tf.constant(
    [[1, 23, 4, 5, 0, 0]]), initial_state)

fr_vocab_size = len(fr_tokenizer.word_index) + 1
decoder = Decoder(fr_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)
de_initial_state = test_encoder_output[1:]
test_decoder_output = decoder(tf.constant(
    [[1, 3, 5, 7, 9, 0, 0, 0]]), de_initial_state)

# ================ ADD GRADIENT CLIPPING OPTION =====================
optimizer = tf.keras.optimizers.Adam(clipnorm=5.0)

# ================== THIS NEEDS TO BE RE-RUN ========================
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out, en_initial_states):
    with tf.GradientTape() as tape:
        en_outputs = encoder(source_seq, en_initial_states)
        en_states = en_outputs[1:]
        de_states = en_states

        de_outputs = decoder(target_seq_in, de_states)
        logits = de_outputs[0]
        loss = loss_func(target_seq_out, logits)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss

# ===================== THE TRAINING LOOP ===========================
for e in range(NUM_EPOCHS):
    en_initial_states = encoder.init_states(BATCH_SIZE)
    
    predict()

    for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
        loss = train_step(source_seq, target_seq_in,
                          target_seq_out, en_initial_states)
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                e + 1, batch, loss.numpy()))

It's time for test

So after a while, our vanilla Seq2Seq has completed its training. Let's see how well it can translate. For the sake of simplicity, we will use the same 20 English - French pairs that we defined above.


In [0]:
test_sents = (
    'What a ridiculous concept!',
    'Your idea is not entirely crazy.',
    "A man's worth lies in what he is.",
    'What he did is very wrong.',
    "All three of you need to do that.",
    "Are you giving me another chance?",
    "Both Tom and Mary work as models.",
    "Can I have a few minutes, please?",
    "Could you close the door, please?",
    "Did you plant pumpkins this year?",
    "Do you ever study in the library?",
    "Don't be deceived by appearances.",
    "Excuse me. Can you speak English?",
    "Few people know the true meaning.",
    "Germany produced many scientists.",
    "Guess whose birthday it is today.",
    "He acted like he owned the place.",
    "Honesty will pay in the long run.",
    "How do we know this isn't a trap?",
    "I can't believe you're giving up.",
)

for test_sent in test_sents:
    test_sequence = normalize_string(test_sent)
    predict(test_sequence)

Hmm, somewhat acceptable, I think. But obviously, there is a lot of room for improvement. We're gonna look at the vanilla Seq2Seq's problems and see what we can do.

Let's talk about Attention Mechanisms

As we saw from the result above, while the vanilla Seq2Seq can overfit a small dataset, it struggled with learning from a much larger one, especially long sequences.

A solution for that? Well, what if every time step within the Decoder had access to all of the Encoder's outputs and could decide which part to focus on? That's the idea behind attention mechanisms.

What is an attention layer made of?

So, we understand the idea of having an attention mechanism. Let's see things from the inside. Technically, we need to compute the following:

  • The alignment vector
  • The context vector

The alignment vector has the same length as the source sequence. Each of its values is the score (or the probability) of the corresponding word within the source sequence.

The context vector, on the other hand, is simply the weighted average of the Encoder's output. It will then be used by the Decoder to compute the final output.
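
For readers who prefer formulas, these two quantities can be written as follows. The notation is taken from the Luong et al. paper listed in the references rather than from the code: h_t is the Decoder's output at target step t and \bar{h}_s is the Encoder's output at source position s. The three score functions correspond to the 'dot', 'general' and 'concat' options used in this notebook.

```latex
\begin{aligned}
\mathrm{score}(h_t, \bar{h}_s) &=
\begin{cases}
h_t^\top \bar{h}_s & \text{(dot)} \\
h_t^\top W_a \bar{h}_s & \text{(general)} \\
v_a^\top \tanh\left(W_a [h_t; \bar{h}_s]\right) & \text{(concat)}
\end{cases} \\
a_t(s) &= \frac{\exp\left(\mathrm{score}(h_t, \bar{h}_s)\right)}
               {\sum_{s'} \exp\left(\mathrm{score}(h_t, \bar{h}_{s'})\right)} \\
c_t &= \sum_s a_t(s)\, \bar{h}_s
\end{aligned}
```

Here a_t is the alignment vector and c_t is the context vector.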

If you want a fancier explanation, please refer to my blog post (link above). I have a bunch of images there.

Below is how we implement Luong attention in Python.


In [0]:
class LuongAttention(tf.keras.Model):
    def __init__(self, rnn_size, attention_func):
        super(LuongAttention, self).__init__()
        self.attention_func = attention_func

        if attention_func not in ['dot', 'general', 'concat']:
            raise ValueError(
                'Unknown attention score function! Must be either dot, general or concat.')

        if attention_func == 'general':
            # General score function
            self.wa = tf.keras.layers.Dense(rnn_size)
        elif attention_func == 'concat':
            # Concat score function
            self.wa = tf.keras.layers.Dense(rnn_size, activation='tanh')
            self.va = tf.keras.layers.Dense(1)

    def call(self, decoder_output, encoder_output):
        if self.attention_func == 'dot':
            # Dot score function: decoder_output (dot) encoder_output
            # decoder_output has shape: (batch_size, 1, rnn_size)
            # encoder_output has shape: (batch_size, max_len, rnn_size)
            # => score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, encoder_output, transpose_b=True)
        elif self.attention_func == 'general':
            # General score function: decoder_output (dot) (Wa (dot) encoder_output)
            # decoder_output has shape: (batch_size, 1, rnn_size)
            # encoder_output has shape: (batch_size, max_len, rnn_size)
            # => score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, self.wa(
                encoder_output), transpose_b=True)
        elif self.attention_func == 'concat':
            # Concat score function: va (dot) tanh(Wa (dot) concat(decoder_output, encoder_output))
            # Decoder output must be tiled along the time axis to match encoder output's shape first
            decoder_output = tf.tile(
                decoder_output, [1, encoder_output.shape[1], 1])

            # Concat => Wa => va
            # (batch_size, max_len, 2 * rnn_size) => (batch_size, max_len, rnn_size) => (batch_size, max_len, 1)
            score = self.va(
                self.wa(tf.concat((decoder_output, encoder_output), axis=-1)))

            # Transpose score vector to have the same shape as other two above
            # (batch_size, max_len, 1) => (batch_size, 1, max_len)
            score = tf.transpose(score, [0, 2, 1])

        # alignment a_t = softmax(score)
        alignment = tf.nn.softmax(score, axis=2)

        # context vector c_t is the alignment-weighted sum of encoder output
        context = tf.matmul(alignment, encoder_output)

        return context, alignment
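
To check the shapes without TensorFlow, the 'dot' branch can be reproduced in a few lines of NumPy. This is only a sketch of the same computation on random data; dot_attention and softmax are hypothetical helpers for this illustration, not part of the model above.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def dot_attention(decoder_output, encoder_output):
    # decoder_output: (batch_size, 1, rnn_size)
    # encoder_output: (batch_size, max_len, rnn_size)
    score = decoder_output @ encoder_output.transpose(0, 2, 1)
    alignment = softmax(score, axis=2)    # (batch_size, 1, max_len)
    context = alignment @ encoder_output  # (batch_size, 1, rnn_size)
    return context, alignment


rng = np.random.default_rng(0)
enc = rng.normal(size=(2, 5, 4))  # batch of 2, 5 source words, rnn_size 4
dec = rng.normal(size=(2, 1, 4))  # one decoder step per batch element
context, alignment = dot_attention(dec, enc)

print(context.shape, alignment.shape)
print(alignment.sum(axis=2))
```

Each alignment row sums to 1, so the context vector is a weighted average of the Encoder's outputs. If the Decoder output were all zeros, every source word would get the same weight and the context would be the plain mean.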

Update the Decoder to use Luong Attention

In order to use the attention above, we need to make a couple of small changes. We will start off with the Decoder.


In [0]:
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, rnn_size, attention_func):
        super(Decoder, self).__init__()
        self.attention = LuongAttention(rnn_size, attention_func)
        self.rnn_size = rnn_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            rnn_size, return_sequences=True, return_state=True)
        self.wc = tf.keras.layers.Dense(rnn_size, activation='tanh')
        self.ws = tf.keras.layers.Dense(vocab_size)

    def call(self, sequence, state, encoder_output):
        # Remember that the input to the decoder
        # is now a batch of one-word sequences,
        # which means that its shape is (batch_size, 1)
        embed = self.embedding(sequence)

        # Therefore, the lstm_out has shape (batch_size, 1, rnn_size)
        lstm_out, state_h, state_c = self.lstm(embed, initial_state=state)

        # Use self.attention to compute the context and alignment vectors
        # context vector's shape: (batch_size, 1, rnn_size)
        # alignment vector's shape: (batch_size, 1, source_length)
        context, alignment = self.attention(lstm_out, encoder_output)

        # Combine the context vector and the LSTM output
        # Before combined, both have shape of (batch_size, 1, rnn_size),
        # so let's squeeze the axis 1 first
        # After combined, it will have shape of (batch_size, 2 * rnn_size)
        lstm_out = tf.concat(
            [tf.squeeze(context, 1), tf.squeeze(lstm_out, 1)], 1)

        # lstm_out now has shape (batch_size, rnn_size)
        lstm_out = self.wc(lstm_out)

        # Finally, it is converted back to vocabulary space: (batch_size, vocab_size)
        logits = self.ws(lstm_out)

        return logits, state_h, state_c, alignment
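
The last steps of call correspond to what the Luong et al. paper calls the attentional hidden state. Written as formulas in the paper's notation (the paper omits bias terms, whereas the Dense layers wc and ws above include them by default):

```latex
\begin{aligned}
\tilde{h}_t &= \tanh\left(W_c [c_t; h_t]\right) \\
p(y_t \mid y_{<t}, x) &= \mathrm{softmax}\left(W_s \tilde{h}_t\right)
\end{aligned}
```

The Decoder returns the logits W_s \tilde{h}_t; the softmax is applied inside the loss function, which was created with from_logits=True.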

Rebuild the model

Next, let's rebuild the model. Notice that we must now pass the Encoder's output to the Decoder.


In [0]:
# Set the score function to compute alignment vectors
# Can choose between 'dot', 'general' or 'concat'
ATTENTION_FUNC = 'concat'

encoder = Encoder(en_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)

decoder = Decoder(fr_vocab_size, EMBEDDING_SIZE, LSTM_SIZE, ATTENTION_FUNC)

# These lines can be used for debugging purpose
# Or can be seen as a way to build the models
initial_state = encoder.init_states(1)
encoder_outputs = encoder(tf.constant([[1]]), initial_state)
decoder_outputs = decoder(tf.constant(
    [[1]]), encoder_outputs[1:], encoder_outputs[0])

Modify the train_step function

Next, we need to modify the function for training to compute the Decoder's output one step at a time, which means that we need to explicitly create a loop.


In [0]:
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out, en_initial_states):
    loss = 0
    with tf.GradientTape() as tape:
        en_outputs = encoder(source_seq, en_initial_states)
        en_states = en_outputs[1:]
        de_state_h, de_state_c = en_states

        # We need to create a loop to iterate through the target sequences
        for i in range(target_seq_out.shape[1]):
            # Input to the decoder must have shape of (batch_size, length)
            # so we need to expand one dimension
            decoder_in = tf.expand_dims(target_seq_in[:, i], 1)
            logit, de_state_h, de_state_c, _ = decoder(
                decoder_in, (de_state_h, de_state_c), en_outputs[0])

            # The loss is now accumulated over the time steps
            loss += loss_func(target_seq_out[:, i], logit)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss / target_seq_out.shape[1]

Modify the predict function

The change to make in the predict function is simple: we need to collect and return the alignment vectors. They will help generate some fancy attention heatmaps for visualization.


In [0]:
def predict(test_source_text=None):
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)

    en_initial_states = encoder.init_states(1)
    en_outputs = encoder(tf.constant(test_source_seq), en_initial_states)

    de_input = tf.constant([[fr_tokenizer.word_index['<start>']]])
    de_state_h, de_state_c = en_outputs[1:]
    out_words = []
    alignments = []

    while True:
        de_output, de_state_h, de_state_c, alignment = decoder(
            de_input, (de_state_h, de_state_c), en_outputs[0])
        de_input = tf.expand_dims(tf.argmax(de_output, -1), 0)
        out_words.append(fr_tokenizer.index_word[de_input.numpy()[0][0]])

        alignments.append(alignment.numpy())

        if out_words[-1] == '<end>' or len(out_words) >= 20:
            break

    print(' '.join(out_words))
    return np.array(alignments), test_source_text.split(' '), out_words
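
A quick note on the shape of the returned alignments, since it matters for plotting. The sketch below uses made-up uniform weights instead of real model output, and the sizes are arbitrary; it only shows how the per-step matrices stack up and how the two singleton axes can be squeezed away.

```python
import numpy as np

source_length, num_generated = 5, 3

# One alignment matrix of shape (1, 1, source_length) per generated word
alignments = [np.full((1, 1, source_length), 1.0 / source_length)
              for _ in range(num_generated)]

stacked = np.array(alignments)           # (num_generated, 1, 1, source_length)
attention = np.squeeze(stacked, (1, 2))  # rows: target words, columns: source words

print(stacked.shape)
print(attention.shape)
```

The resulting 2-D matrix has one row per generated word and one column per source word, which is the layout a heatmap expects.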

Let's train Seq2Seq with Luong attention

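The training loop is basically the same, although I recommend putting the calls to the predict function inside a try/except block. Early in training, the model may predict index 0 (the padding index), which has no entry in fr_tokenizer.index_word, so the lookup would raise a KeyError and crash the loop.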


In [0]:
for e in range(NUM_EPOCHS):
    en_initial_states = encoder.init_states(BATCH_SIZE)
    
    for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
        loss = train_step(source_seq, target_seq_in,
                          target_seq_out, en_initial_states)
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                e + 1, batch, loss.numpy()))

    try:
        predict()
        predict("How are you today ?")
    except Exception:
        continue

Let's make some translations

Similar to what we did with the vanilla Seq2Seq, it's time to see how the model with Luong attention can translate the same 20 pairs of English - French sentences.

Furthermore, we can also visualize which source words the model focused on when producing each translated word.


In [0]:
import matplotlib.pyplot as plt
import imageio

if not os.path.exists('heatmap'):
    os.makedirs('heatmap')

test_sents = (
    'What a ridiculous concept!',
    'Your idea is not entirely crazy.',
    "A man's worth lies in what he is.",
    'What he did is very wrong.',
    "All three of you need to do that.",
    "Are you giving me another chance?",
    "Both Tom and Mary work as models.",
    "Can I have a few minutes, please?",
    "Could you close the door, please?",
    "Did you plant pumpkins this year?",
    "Do you ever study in the library?",
    "Don't be deceived by appearances.",
    "Excuse me. Can you speak English?",
    "Few people know the true meaning.",
    "Germany produced many scientists.",
    "Guess whose birthday it is today.",
    "He acted like he owned the place.",
    "Honesty will pay in the long run.",
    "How do we know this isn't a trap?",
    "I can't believe you're giving up.",
)

filenames = []

for i, test_sent in enumerate(test_sents):
    test_sequence = normalize_string(test_sent)
    alignments, source, prediction = predict(test_sequence)
    attention = np.squeeze(alignments, (1, 2))
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='jet')
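    # Place one tick per word so that labels stay aligned with the cells
    ax.set_xticks(range(len(source)))
    ax.set_xticklabels(source, rotation=90)
    ax.set_yticks(range(len(prediction)))
    ax.set_yticklabels(prediction)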

    filenames.append('heatmap/test_{}.png'.format(i))
    plt.savefig('heatmap/test_{}.png'.format(i))
    plt.close()

with imageio.get_writer('translation_heatmaps.gif', mode='I', duration=2) as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)

As you can see, the model with Luong attention appears to produce better translations than the vanilla model, at least judging by eye. For a more objective comparison, you can compute BLEU scores for both models.
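
As a rough idea of what such a score involves, here is a simplified, sentence-level BLEU sketch in pure Python. It is an assumption-laden illustration, not a reference implementation: simple_bleu is a hypothetical helper that uses only unigrams and bigrams, a single reference and no smoothing. For real evaluations, prefer an established implementation.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def simple_bleu(reference, hypothesis, max_n=2):
    if not hypothesis:
        return 0.0
    log_precisions = []
    for n in range(1, max_n + 1):
        hyp_counts = ngrams(hypothesis, n)
        ref_counts = ngrams(reference, n)
        # Clipped counts: an n-gram is credited at most as often as it
        # appears in the reference
        overlap = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        total = sum(hyp_counts.values())
        if overlap == 0 or total == 0:
            return 0.0  # no smoothing in this sketch
        log_precisions.append(math.log(overlap / total))
    # Brevity penalty: punish hypotheses shorter than the reference
    if len(hypothesis) > len(reference):
        bp = 1.0
    else:
        bp = math.exp(1 - len(reference) / len(hypothesis))
    return bp * math.exp(sum(log_precisions) / max_n)


ref = 'quel concept ridicule !'.split()
print(simple_bleu(ref, ref))
print(simple_bleu(ref, 'quel concept !'.split()))
```

A perfect match scores 1.0, while a translation that drops a word is penalized both through the bigram precision and through the brevity penalty.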

As for the GIF image, you can download it to your local machine and open it with any program of your choice.


In [0]:
from google.colab import files

files.download('translation_heatmaps.gif')

Final words

And that is that. We have finished creating a neural machine translation model with the renowned Seq2Seq model and a Luong-style attention mechanism.

For a more detailed explanation, feel free to jump to my blog post at Neural Machine Translation With Attention Mechanism.

If you have any problems, don't hesitate to let me know. Thank you for reading such a long post.

Reference

  • Neural Machine Translation With Attention Mechanism: link
  • Effective Approaches to Attention-based Neural Machine Translation paper (Luong attention): link
  • Tensorflow Neural Machine Translation with (Bahdanau) Attention tutorial: link
  • Luong’s Neural Machine Translation repository: link