In [0]:
##### Copyright 2020 The TensorFlow Authors.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

TensorFlow Addons Networks : Sequence-to-Sequence NMT with Attention Mechanism

View on TensorFlow.org View source on GitHub Download notebook

Overview

This notebook gives a brief introduction into the Sequence to Sequence Model Architecture In this noteboook we broadly cover four essential topics necessary for Neural Machine Translation:

  • Data cleaning
  • Data preparation
  • Neural Translation Model with Attention
  • Final Translation

The basic idea behind such a model though, is only the encoder-decoder architecture. These networks are usually used for a variety of tasks like text-summerization, Machine translation, Image Captioning, etc. This tutorial provideas a hands-on understanding of the concept, explaining the technical jargons wherever necessary. We focus on the task of Neural Machine Translation (NMT) which was the very first testbed for seq2seq models.

Setup

Additional Resources:

These are a lst of resurces you must install in order to allow you to run this notebook:

  1. German-English Dataset

The dataset should be downloaded, in order to compile this notebook, the embeddings can be used, as they are pretrained. Though, we carry out our own training here !!


In [0]:
#download data
print("Downloading Dataset:")
!wget --quiet http://www.manythings.org/anki/deu-eng.zip
!unzip deu-eng.zip


Downloading Dataset:
Archive:  deu-eng.zip
  inflating: deu.txt                 
  inflating: _about.txt              
Downloading Dataset:
Archive:  deu-eng.zip
replace deu.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [0]:
import csv
import string
import re
from pickle import dump
from unicodedata import normalize
from numpy import array
import itertools
from pickle import load
from tensorflow.keras.utils import to_categorical
from keras.utils.vis_utils import plot_model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Embedding
from pickle import load
from numpy import array
from numpy import argmax
import tensorflow as tf
from keras.models import load_model
from nltk.translate.bleu_score import corpus_bleu
from sklearn.model_selection import train_test_split
import tensorflow_addons as tfa


Using TensorFlow backend.

Data Cleaning

Our data set is a German-English translation dataset. It contains 152,820 pairs of English to German phases, one pair per line with a tab separating the language. These dataset though organized needs cleaning before we can work on it. This will enable us to remove unnecessary bumps that may come in during the training.


In [0]:
# load doc into memory
def load_documnet(filename):
# open the file as read only
  file = open(filename, mode='rt', encoding='utf-8')
  # read all text
  text = file.read()
  # close the file
  file.close()
  return text

# split a loaded document into sentences
def doc_sep_pair(doc):
  lines = doc.strip().split('\n')
  pairs = [line.split('\t') for line in  lines]
  return pairs

# clean a list of lines
def clean_sentences(lines):
  cleaned = list()
  re_print = re.compile('[^%s]' % re.escape(string.printable))
  # prepare translation table 
  table = str.maketrans('', '', string.punctuation)
  for pair in lines:
    clean_pair = list()
    for line in pair:
      # normalizing unicode characters
      line = normalize('NFD', line).encode('ascii', 'ignore')
      line = line.decode('UTF-8')
      # tokenize on white space
      line = line.split()
      # convert to lowercase
      line = [word.lower() for word in line]
      # removing punctuation
      line = [word.translate(table) for word in line]
      # removing non-printable chars form each token
      line = [re_print.sub('', w) for w in line]
      # removing tokens with numbers
      line = [word for word in line if word.isalpha()]

      line.insert(0,'<start> ')
      line.append(' <end>')
      # store as string
      clean_pair.append(' '.join(line))
    cleaned.append(clean_pair)
  return array(cleaned)

Saving the Cleaned Dataset


In [0]:
# load dataset
filename = 'deu.txt' #change filename if necessary
doc = load_documnet(filename)

#clean sentences and save clean data
pairs = doc_sep_pair(doc)
clean_sentences = clean_sentences(pairs)
raw_data = clean_sentences
data = raw_data[:10000, :2] 
import numpy as np
raw_data_en = list()
raw_data_ge = list()
for data1 in data:
  raw_data_en.append(data1[0]),raw_data_ge.append(data1[1])

Tokenization


In [0]:
en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
en_tokenizer.fit_on_texts(raw_data_en)

data_en = en_tokenizer.texts_to_sequences(raw_data_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en,padding='post')

ge_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
ge_tokenizer.fit_on_texts(raw_data_ge)

data_ge = ge_tokenizer.texts_to_sequences(raw_data_ge)
data_ge = tf.keras.preprocessing.sequence.pad_sequences(data_ge,padding='post')

In [0]:
def max_len(tensor):
    #print( np.argmax([len(t) for t in tensor]))
    return max( len(t) for t in tensor)

Model Parameters


In [0]:
X_train,  X_test, Y_train, Y_test = train_test_split(data_en,data_ge,test_size=0.2)
BATCH_SIZE = 64
BUFFER_SIZE = len(X_train)
steps_per_epoch = BUFFER_SIZE//BATCH_SIZE
embedding_dims = 256
rnn_units = 1024
dense_units = 1024
Dtype = tf.float32   #used to initialize DecoderCell Zero state

Dataset Prepration


In [0]:
Tx = max_len(data_en)
Ty = max_len(data_ge)  

input_vocab_size = len(en_tokenizer.word_index)+1  
output_vocab_size = len(ge_tokenizer.word_index)+ 1
dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
example_X, example_Y = next(iter(dataset))
#print(example_X.shape) 
#print(example_Y.shape)

Defining NMT Model


In [0]:
#ENCODER
class EncoderNetwork(tf.keras.Model):
    def __init__(self,input_vocab_size,embedding_dims, rnn_units ):
        super().__init__()
        self.encoder_embedding = tf.keras.layers.Embedding(input_dim=input_vocab_size,
                                                           output_dim=embedding_dims)
        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units,return_sequences=True, 
                                                     return_state=True )
    
#DECODER
class DecoderNetwork(tf.keras.Model):
    def __init__(self,output_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.decoder_embedding = tf.keras.layers.Embedding(input_dim=output_vocab_size,
                                                           output_dim=embedding_dims) 
        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)
        # Sampler
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        # Create attention mechanism with memory = None
        self.attention_mechanism = self.build_attention_mechanism(dense_units,None,BATCH_SIZE*[Tx])
        self.rnn_cell =  self.build_rnn_cell(BATCH_SIZE)
        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler= self.sampler,
                                                output_layer=self.dense_layer)

    def build_attention_mechanism(self, units,memory, memory_sequence_length):
        return tfa.seq2seq.LuongAttention(units, memory = memory, 
                                          memory_sequence_length=memory_sequence_length)
        #return tfa.seq2seq.BahdanauAttention(units, memory = memory, memory_sequence_length=memory_sequence_length)

    # wrap decodernn cell  
    def build_rnn_cell(self, batch_size ):
        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell, self.attention_mechanism,
                                                attention_layer_size=dense_units)
        return rnn_cell
    
    def build_decoder_initial_state(self, batch_size, encoder_state,Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size = batch_size, 
                                                                dtype = Dtype)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state) 
        return decoder_initial_state

encoderNetwork = EncoderNetwork(input_vocab_size,embedding_dims, rnn_units)
decoderNetwork = DecoderNetwork(output_vocab_size,embedding_dims, rnn_units)
optimizer = tf.keras.optimizers.Adam()

Initializing Training functions


In [0]:
def loss_function(y_pred, y):
   
    #shape of y [batch_size, ty]
    #shape of y_pred [batch_size, Ty, output_vocab_size] 
    sparsecategoricalcrossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                                                  reduction='none')
    loss = sparsecategoricalcrossentropy(y_true=y, y_pred=y_pred)
    mask = tf.logical_not(tf.math.equal(y,0))   #output 0 for y=0 else output 1
    mask = tf.cast(mask, dtype=loss.dtype)
    loss = mask* loss
    loss = tf.reduce_mean(loss)
    return loss


def train_step(input_batch, output_batch,encoder_initial_cell_state):
    #initialize loss = 0
    loss = 0
    with tf.GradientTape() as tape:
        encoder_emb_inp = encoderNetwork.encoder_embedding(input_batch)
        a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp, 
                                                        initial_state =encoder_initial_cell_state)

        #[last step activations,last memory_state] of encoder passed as input to decoder Network
        
         
        # Prepare correct Decoder input & output sequence data
        decoder_input = output_batch[:,:-1] # ignore <end>
        #compare logits with timestepped +1 version of decoder_input
        decoder_output = output_batch[:,1:] #ignore <start>


        # Decoder Embeddings
        decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)

        #Setting up decoder memory from encoder output and Zero State for AttentionWrapperState
        decoderNetwork.attention_mechanism.setup_memory(a)
        decoder_initial_state = decoderNetwork.build_decoder_initial_state(BATCH_SIZE,
                                                                           encoder_state=[a_tx, c_tx],
                                                                           Dtype=tf.float32)
        
        #BasicDecoderOutput        
        outputs, _, _ = decoderNetwork.decoder(decoder_emb_inp,initial_state=decoder_initial_state,
                                               sequence_length=BATCH_SIZE*[Ty-1])

        logits = outputs.rnn_output
        #Calculate loss

        loss = loss_function(logits, decoder_output)

    #Returns the list of all layer variables / weights.
    variables = encoderNetwork.trainable_variables + decoderNetwork.trainable_variables  
    # differentiate loss wrt variables
    gradients = tape.gradient(loss, variables)

    #grads_and_vars – List of(gradient, variable) pairs.
    grads_and_vars = zip(gradients,variables)
    optimizer.apply_gradients(grads_and_vars)
    return loss

In [0]:
#RNN LSTM hidden and memory state initializer
def initialize_initial_state():
        return [tf.zeros((BATCH_SIZE, rnn_units)), tf.zeros((BATCH_SIZE, rnn_units))]

Training


In [0]:
epochs = 15
for i in range(1, epochs+1):

    encoder_initial_cell_state = initialize_initial_state()
    total_loss = 0.0

    for ( batch , (input_batch, output_batch)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(input_batch, output_batch, encoder_initial_cell_state)
        total_loss += batch_loss
        if (batch+1)%5 == 0:
            print("total loss: {} epoch {} batch {} ".format(batch_loss.numpy(), i, batch+1))


total loss: 3.9475886821746826 epoch 1 batch 5 
total loss: 2.912675380706787 epoch 1 batch 10 
total loss: 2.2815799713134766 epoch 1 batch 15 
total loss: 2.183105230331421 epoch 1 batch 20 
total loss: 2.2029659748077393 epoch 1 batch 25 
total loss: 2.1441750526428223 epoch 1 batch 30 
total loss: 2.0407521724700928 epoch 1 batch 35 
total loss: 2.010983943939209 epoch 1 batch 40 
total loss: 2.073960542678833 epoch 1 batch 45 
total loss: 1.990903615951538 epoch 1 batch 50 
total loss: 2.074843406677246 epoch 1 batch 55 
total loss: 2.011075496673584 epoch 1 batch 60 
total loss: 1.9887497425079346 epoch 1 batch 65 
total loss: 1.892216682434082 epoch 1 batch 70 
total loss: 1.9781100749969482 epoch 1 batch 75 
total loss: 1.858034372329712 epoch 1 batch 80 
total loss: 1.7822014093399048 epoch 1 batch 85 
total loss: 1.8063167333602905 epoch 1 batch 90 
total loss: 1.7766847610473633 epoch 1 batch 95 
total loss: 1.8258825540542603 epoch 1 batch 100 
total loss: 1.8362295627593994 epoch 1 batch 105 
total loss: 1.7136967182159424 epoch 1 batch 110 
total loss: 1.9544591903686523 epoch 1 batch 115 
total loss: 1.8494930267333984 epoch 1 batch 120 
total loss: 1.6849143505096436 epoch 1 batch 125 
total loss: 1.6526412963867188 epoch 2 batch 5 
total loss: 1.674989938735962 epoch 2 batch 10 
total loss: 1.5851391553878784 epoch 2 batch 15 
total loss: 1.6815704107284546 epoch 2 batch 20 
total loss: 1.6417633295059204 epoch 2 batch 25 
total loss: 1.697832703590393 epoch 2 batch 30 
total loss: 1.7252113819122314 epoch 2 batch 35 
total loss: 1.537140965461731 epoch 2 batch 40 
total loss: 1.6764580011367798 epoch 2 batch 45 
total loss: 1.453371286392212 epoch 2 batch 50 
total loss: 1.6771221160888672 epoch 2 batch 55 
total loss: 1.5605512857437134 epoch 2 batch 60 
total loss: 1.6059997081756592 epoch 2 batch 65 
total loss: 1.467176079750061 epoch 2 batch 70 
total loss: 1.609415054321289 epoch 2 batch 75 
total loss: 1.5329309701919556 epoch 2 batch 80 
total loss: 1.6187160015106201 epoch 2 batch 85 
total loss: 1.5867189168930054 epoch 2 batch 90 
total loss: 1.5069472789764404 epoch 2 batch 95 
total loss: 1.64217209815979 epoch 2 batch 100 
total loss: 1.5193077325820923 epoch 2 batch 105 
total loss: 1.5160114765167236 epoch 2 batch 110 
total loss: 1.516736626625061 epoch 2 batch 115 
total loss: 1.4899837970733643 epoch 2 batch 120 
total loss: 1.6041948795318604 epoch 2 batch 125 
total loss: 1.502350926399231 epoch 3 batch 5 
total loss: 1.360275149345398 epoch 3 batch 10 
total loss: 1.3449124097824097 epoch 3 batch 15 
total loss: 1.378374457359314 epoch 3 batch 20 
total loss: 1.4500166177749634 epoch 3 batch 25 
total loss: 1.3589526414871216 epoch 3 batch 30 
total loss: 1.3434584140777588 epoch 3 batch 35 
total loss: 1.2667752504348755 epoch 3 batch 40 
total loss: 1.4497849941253662 epoch 3 batch 45 
total loss: 1.5071169137954712 epoch 3 batch 50 
total loss: 1.344785451889038 epoch 3 batch 55 
total loss: 1.4199110269546509 epoch 3 batch 60 
total loss: 1.3664555549621582 epoch 3 batch 65 
total loss: 1.3798571825027466 epoch 3 batch 70 
total loss: 1.4127501249313354 epoch 3 batch 75 
total loss: 1.3040590286254883 epoch 3 batch 80 
total loss: 1.4017330408096313 epoch 3 batch 85 
total loss: 1.4011389017105103 epoch 3 batch 90 
total loss: 1.295208215713501 epoch 3 batch 95 
total loss: 1.3480514287948608 epoch 3 batch 100 
total loss: 1.2870609760284424 epoch 3 batch 105 
total loss: 1.4333269596099854 epoch 3 batch 110 
total loss: 1.3486473560333252 epoch 3 batch 115 
total loss: 1.26927649974823 epoch 3 batch 120 
total loss: 1.2078845500946045 epoch 3 batch 125 
total loss: 1.1198420524597168 epoch 4 batch 5 
total loss: 1.0763131380081177 epoch 4 batch 10 
total loss: 1.1939853429794312 epoch 4 batch 15 
total loss: 1.2020100355148315 epoch 4 batch 20 
total loss: 1.0719928741455078 epoch 4 batch 25 
total loss: 1.124625325202942 epoch 4 batch 30 
total loss: 1.1307220458984375 epoch 4 batch 35 
total loss: 1.1385688781738281 epoch 4 batch 40 
total loss: 1.1286941766738892 epoch 4 batch 45 
total loss: 1.1118738651275635 epoch 4 batch 50 
total loss: 1.0924361944198608 epoch 4 batch 55 
total loss: 1.2378876209259033 epoch 4 batch 60 
total loss: 1.1472713947296143 epoch 4 batch 65 
total loss: 1.1867095232009888 epoch 4 batch 70 
total loss: 1.1062105894088745 epoch 4 batch 75 
total loss: 1.0883691310882568 epoch 4 batch 80 
total loss: 1.1805391311645508 epoch 4 batch 85 
total loss: 1.3100593090057373 epoch 4 batch 90 
total loss: 1.199307918548584 epoch 4 batch 95 
total loss: 1.1042678356170654 epoch 4 batch 100 
total loss: 1.0394186973571777 epoch 4 batch 105 
total loss: 1.1532882452011108 epoch 4 batch 110 
total loss: 1.0915677547454834 epoch 4 batch 115 
total loss: 1.0750417709350586 epoch 4 batch 120 
total loss: 1.0421984195709229 epoch 4 batch 125 
total loss: 1.0400830507278442 epoch 5 batch 5 
total loss: 1.0274797677993774 epoch 5 batch 10 
total loss: 0.8721328973770142 epoch 5 batch 15 
total loss: 0.9050451517105103 epoch 5 batch 20 
total loss: 0.9007365107536316 epoch 5 batch 25 
total loss: 0.8890656232833862 epoch 5 batch 30 
total loss: 0.8846011161804199 epoch 5 batch 35 
total loss: 0.8554547429084778 epoch 5 batch 40 
total loss: 1.1025922298431396 epoch 5 batch 45 
total loss: 0.9758970141410828 epoch 5 batch 50 
total loss: 1.0573564767837524 epoch 5 batch 55 
total loss: 0.9744541049003601 epoch 5 batch 60 
total loss: 0.9071753621101379 epoch 5 batch 65 
total loss: 0.970922589302063 epoch 5 batch 70 
total loss: 0.9922286868095398 epoch 5 batch 75 
total loss: 0.8951885104179382 epoch 5 batch 80 
total loss: 1.0515273809432983 epoch 5 batch 85 
total loss: 0.9692702293395996 epoch 5 batch 90 
total loss: 0.8851386904716492 epoch 5 batch 95 
total loss: 1.0359522104263306 epoch 5 batch 100 
total loss: 0.9581290483474731 epoch 5 batch 105 
total loss: 0.9426918029785156 epoch 5 batch 110 
total loss: 0.9563409686088562 epoch 5 batch 115 
total loss: 0.9106627702713013 epoch 5 batch 120 
total loss: 0.9571183919906616 epoch 5 batch 125 
total loss: 0.6938820481300354 epoch 6 batch 5 
total loss: 0.760671854019165 epoch 6 batch 10 
total loss: 0.699514627456665 epoch 6 batch 15 
total loss: 0.6691784858703613 epoch 6 batch 20 
total loss: 0.8158406019210815 epoch 6 batch 25 
total loss: 0.7383745908737183 epoch 6 batch 30 
total loss: 0.7447091341018677 epoch 6 batch 35 
total loss: 0.7862703800201416 epoch 6 batch 40 
total loss: 0.8322451710700989 epoch 6 batch 45 
total loss: 0.8715308904647827 epoch 6 batch 50 
total loss: 0.706369161605835 epoch 6 batch 55 
total loss: 0.7995638251304626 epoch 6 batch 60 
total loss: 0.8098785281181335 epoch 6 batch 65 
total loss: 0.6516961455345154 epoch 6 batch 70 
total loss: 0.7424792647361755 epoch 6 batch 75 
total loss: 0.7396417856216431 epoch 6 batch 80 
total loss: 0.7362191677093506 epoch 6 batch 85 
total loss: 0.9558976292610168 epoch 6 batch 90 
total loss: 0.8189946413040161 epoch 6 batch 95 
total loss: 0.7554519176483154 epoch 6 batch 100 
total loss: 0.772563099861145 epoch 6 batch 105 
total loss: 0.8337545394897461 epoch 6 batch 110 
total loss: 0.7600473761558533 epoch 6 batch 115 
total loss: 0.7708126902580261 epoch 6 batch 120 
total loss: 0.6998305320739746 epoch 6 batch 125 
total loss: 0.5774018168449402 epoch 7 batch 5 
total loss: 0.6558392643928528 epoch 7 batch 10 
total loss: 0.5383725762367249 epoch 7 batch 15 
total loss: 0.6437508463859558 epoch 7 batch 20 
total loss: 0.594805121421814 epoch 7 batch 25 
total loss: 0.5795590281486511 epoch 7 batch 30 
total loss: 0.7567883729934692 epoch 7 batch 35 
total loss: 0.5663882493972778 epoch 7 batch 40 
total loss: 0.6014893054962158 epoch 7 batch 45 
total loss: 0.5960389971733093 epoch 7 batch 50 
total loss: 0.633935809135437 epoch 7 batch 55 
total loss: 0.6122901439666748 epoch 7 batch 60 
total loss: 0.6862307786941528 epoch 7 batch 65 
total loss: 0.7035883665084839 epoch 7 batch 70 
total loss: 0.7250910997390747 epoch 7 batch 75 
total loss: 0.6099406480789185 epoch 7 batch 80 
total loss: 0.5820813179016113 epoch 7 batch 85 
total loss: 0.5596920251846313 epoch 7 batch 90 
total loss: 0.6520225405693054 epoch 7 batch 95 
total loss: 0.6929486393928528 epoch 7 batch 100 
total loss: 0.6704176664352417 epoch 7 batch 105 
total loss: 0.6779621839523315 epoch 7 batch 110 
total loss: 0.6607205271720886 epoch 7 batch 115 
total loss: 0.5835480690002441 epoch 7 batch 120 
total loss: 0.6930114030838013 epoch 7 batch 125 
total loss: 0.4390392303466797 epoch 8 batch 5 
total loss: 0.502900242805481 epoch 8 batch 10 
total loss: 0.44953665137290955 epoch 8 batch 15 
total loss: 0.5513165593147278 epoch 8 batch 20 
total loss: 0.5228049159049988 epoch 8 batch 25 
total loss: 0.48368993401527405 epoch 8 batch 30 
total loss: 0.4797203540802002 epoch 8 batch 35 
total loss: 0.5822355151176453 epoch 8 batch 40 
total loss: 0.4232334494590759 epoch 8 batch 45 
total loss: 0.5870698094367981 epoch 8 batch 50 
total loss: 0.48263269662857056 epoch 8 batch 55 
total loss: 0.44463014602661133 epoch 8 batch 60 
total loss: 0.49221715331077576 epoch 8 batch 65 
total loss: 0.5247334837913513 epoch 8 batch 70 
total loss: 0.6095311045646667 epoch 8 batch 75 
total loss: 0.5857402086257935 epoch 8 batch 80 
total loss: 0.46884751319885254 epoch 8 batch 85 
total loss: 0.5228506326675415 epoch 8 batch 90 
total loss: 0.46329981088638306 epoch 8 batch 95 
total loss: 0.5708974003791809 epoch 8 batch 100 
total loss: 0.5332533121109009 epoch 8 batch 105 
total loss: 0.532862663269043 epoch 8 batch 110 
total loss: 0.5066767334938049 epoch 8 batch 115 
total loss: 0.5123209357261658 epoch 8 batch 120 
total loss: 0.49092260003089905 epoch 8 batch 125 
total loss: 0.3689466118812561 epoch 9 batch 5 
total loss: 0.3250238299369812 epoch 9 batch 10 
total loss: 0.44425806403160095 epoch 9 batch 15 
total loss: 0.3422010838985443 epoch 9 batch 20 
total loss: 0.3302001357078552 epoch 9 batch 25 
total loss: 0.3376121520996094 epoch 9 batch 30 
total loss: 0.43184441328048706 epoch 9 batch 35 
total loss: 0.40924182534217834 epoch 9 batch 40 
total loss: 0.38222527503967285 epoch 9 batch 45 
total loss: 0.4478159546852112 epoch 9 batch 50 
total loss: 0.4593771994113922 epoch 9 batch 55 
total loss: 0.3862895369529724 epoch 9 batch 60 
total loss: 0.40882641077041626 epoch 9 batch 65 
total loss: 0.4312051236629486 epoch 9 batch 70 
total loss: 0.41449132561683655 epoch 9 batch 75 
total loss: 0.45340195298194885 epoch 9 batch 80 
total loss: 0.4121376574039459 epoch 9 batch 85 
total loss: 0.5007123947143555 epoch 9 batch 90 
total loss: 0.4919680655002594 epoch 9 batch 95 
total loss: 0.4845644533634186 epoch 9 batch 100 
total loss: 0.5462281107902527 epoch 9 batch 105 
total loss: 0.3803269863128662 epoch 9 batch 110 
total loss: 0.4410593509674072 epoch 9 batch 115 
total loss: 0.44156259298324585 epoch 9 batch 120 
total loss: 0.48795849084854126 epoch 9 batch 125 
total loss: 0.2677317261695862 epoch 10 batch 5 
total loss: 0.2849716544151306 epoch 10 batch 10 
total loss: 0.2650394141674042 epoch 10 batch 15 
total loss: 0.3369045853614807 epoch 10 batch 20 
total loss: 0.27701759338378906 epoch 10 batch 25 
total loss: 0.2801435589790344 epoch 10 batch 30 
total loss: 0.2140922248363495 epoch 10 batch 35 
total loss: 0.3308884799480438 epoch 10 batch 40 
total loss: 0.3573286831378937 epoch 10 batch 45 
total loss: 0.3585323691368103 epoch 10 batch 50 
total loss: 0.3311135172843933 epoch 10 batch 55 
total loss: 0.4792550206184387 epoch 10 batch 60 
total loss: 0.3666520416736603 epoch 10 batch 65 
total loss: 0.3469542860984802 epoch 10 batch 70 
total loss: 0.39160677790641785 epoch 10 batch 75 
total loss: 0.375261127948761 epoch 10 batch 80 
total loss: 0.34272241592407227 epoch 10 batch 85 
total loss: 0.43078818917274475 epoch 10 batch 90 
total loss: 0.2668665647506714 epoch 10 batch 95 
total loss: 0.37373679876327515 epoch 10 batch 100 
total loss: 0.3685140311717987 epoch 10 batch 105 
total loss: 0.3151918351650238 epoch 10 batch 110 
total loss: 0.34442174434661865 epoch 10 batch 115 
total loss: 0.4334893822669983 epoch 10 batch 120 
total loss: 0.3609371781349182 epoch 10 batch 125 
total loss: 0.22984731197357178 epoch 11 batch 5 
total loss: 0.2418794184923172 epoch 11 batch 10 
total loss: 0.2832423150539398 epoch 11 batch 15 
total loss: 0.24624614417552948 epoch 11 batch 20 
total loss: 0.2362026870250702 epoch 11 batch 25 
total loss: 0.24753396213054657 epoch 11 batch 30 
total loss: 0.30054670572280884 epoch 11 batch 35 
total loss: 0.2899046242237091 epoch 11 batch 40 
total loss: 0.2667545676231384 epoch 11 batch 45 
total loss: 0.2646659314632416 epoch 11 batch 50 
total loss: 0.2119644582271576 epoch 11 batch 55 
total loss: 0.2534087002277374 epoch 11 batch 60 
total loss: 0.2593185305595398 epoch 11 batch 65 
total loss: 0.3010985553264618 epoch 11 batch 70 
total loss: 0.30156993865966797 epoch 11 batch 75 
total loss: 0.31427207589149475 epoch 11 batch 80 
total loss: 0.30148079991340637 epoch 11 batch 85 
total loss: 0.31685349345207214 epoch 11 batch 90 
total loss: 0.2858041822910309 epoch 11 batch 95 
total loss: 0.23358872532844543 epoch 11 batch 100 
total loss: 0.3077571392059326 epoch 11 batch 105 
total loss: 0.2969404458999634 epoch 11 batch 110 
total loss: 0.36080026626586914 epoch 11 batch 115 
total loss: 0.2823699116706848 epoch 11 batch 120 
total loss: 0.28497424721717834 epoch 11 batch 125 
total loss: 0.2739666998386383 epoch 12 batch 5 
total loss: 0.247772216796875 epoch 12 batch 10 
total loss: 0.21159425377845764 epoch 12 batch 15 
total loss: 0.24877581000328064 epoch 12 batch 20 
total loss: 0.23003773391246796 epoch 12 batch 25 
total loss: 0.22304224967956543 epoch 12 batch 30 
total loss: 0.2357630431652069 epoch 12 batch 35 
total loss: 0.24652035534381866 epoch 12 batch 40 
total loss: 0.24459195137023926 epoch 12 batch 45 
total loss: 0.2198447287082672 epoch 12 batch 50 
total loss: 0.22670245170593262 epoch 12 batch 55 
total loss: 0.2391890287399292 epoch 12 batch 60 
total loss: 0.2453366070985794 epoch 12 batch 65 
total loss: 0.21846142411231995 epoch 12 batch 70 
total loss: 0.25742220878601074 epoch 12 batch 75 
total loss: 0.2598118185997009 epoch 12 batch 80 
total loss: 0.2885677218437195 epoch 12 batch 85 
total loss: 0.32734522223472595 epoch 12 batch 90 
total loss: 0.3083980083465576 epoch 12 batch 95 
total loss: 0.3234527111053467 epoch 12 batch 100 
total loss: 0.29528990387916565 epoch 12 batch 105 
total loss: 0.27330103516578674 epoch 12 batch 110 
total loss: 0.2824668288230896 epoch 12 batch 115 
total loss: 0.26833924651145935 epoch 12 batch 120 
total loss: 0.3090164065361023 epoch 12 batch 125 
total loss: 0.18143436312675476 epoch 13 batch 5 
total loss: 0.24107468128204346 epoch 13 batch 10 
total loss: 0.1723310351371765 epoch 13 batch 15 
total loss: 0.2374371737241745 epoch 13 batch 20 
total loss: 0.18838974833488464 epoch 13 batch 25 
total loss: 0.1868618130683899 epoch 13 batch 30 
total loss: 0.2468196451663971 epoch 13 batch 35 
total loss: 0.18816381692886353 epoch 13 batch 40 
total loss: 0.2015218436717987 epoch 13 batch 45 
total loss: 0.17972926795482635 epoch 13 batch 50 
total loss: 0.19488045573234558 epoch 13 batch 55 
total loss: 0.179433211684227 epoch 13 batch 60 
total loss: 0.18720710277557373 epoch 13 batch 65 
total loss: 0.26200735569000244 epoch 13 batch 70 
total loss: 0.2021588832139969 epoch 13 batch 75 
total loss: 0.2547597587108612 epoch 13 batch 80 
total loss: 0.2753807604312897 epoch 13 batch 85 
total loss: 0.27378445863723755 epoch 13 batch 90 
total loss: 0.24202470481395721 epoch 13 batch 95 
total loss: 0.22158583998680115 epoch 13 batch 100 
total loss: 0.22244706749916077 epoch 13 batch 105 
total loss: 0.23681640625 epoch 13 batch 110 
total loss: 0.2990795373916626 epoch 13 batch 115 
total loss: 0.2641446888446808 epoch 13 batch 120 
total loss: 0.23204472661018372 epoch 13 batch 125 
total loss: 0.1728627234697342 epoch 14 batch 5 
total loss: 0.17328783869743347 epoch 14 batch 10 
total loss: 0.20071764290332794 epoch 14 batch 15 
total loss: 0.15985815227031708 epoch 14 batch 20 
total loss: 0.16585272550582886 epoch 14 batch 25 
total loss: 0.17702646553516388 epoch 14 batch 30 
total loss: 0.19213533401489258 epoch 14 batch 35 
total loss: 0.1678582727909088 epoch 14 batch 40 
total loss: 0.17316825687885284 epoch 14 batch 45 
total loss: 0.18272583186626434 epoch 14 batch 50 
total loss: 0.2643834352493286 epoch 14 batch 55 
total loss: 0.1914786398410797 epoch 14 batch 60 
total loss: 0.24789147078990936 epoch 14 batch 65 
total loss: 0.20449848473072052 epoch 14 batch 70 
total loss: 0.20783527195453644 epoch 14 batch 75 
total loss: 0.20063571631908417 epoch 14 batch 80 
total loss: 0.22110366821289062 epoch 14 batch 85 
total loss: 0.27967819571495056 epoch 14 batch 90 
total loss: 0.21627402305603027 epoch 14 batch 95 
total loss: 0.2716841697692871 epoch 14 batch 100 
total loss: 0.26125216484069824 epoch 14 batch 105 
total loss: 0.28036823868751526 epoch 14 batch 110 
total loss: 0.2875978350639343 epoch 14 batch 115 
total loss: 0.24142596125602722 epoch 14 batch 120 
total loss: 0.2443583458662033 epoch 14 batch 125 
total loss: 0.12856344878673553 epoch 15 batch 5 
total loss: 0.19890321791172028 epoch 15 batch 10 
total loss: 0.21472203731536865 epoch 15 batch 15 
total loss: 0.1831301748752594 epoch 15 batch 20 
total loss: 0.18663254380226135 epoch 15 batch 25 
total loss: 0.1720810979604721 epoch 15 batch 30 
total loss: 0.1868405044078827 epoch 15 batch 35 
total loss: 0.22495588660240173 epoch 15 batch 40 
total loss: 0.20775218307971954 epoch 15 batch 45 
total loss: 0.1730271875858307 epoch 15 batch 50 
total loss: 0.2216344177722931 epoch 15 batch 55 
total loss: 0.21534466743469238 epoch 15 batch 60 
total loss: 0.16373391449451447 epoch 15 batch 65 
total loss: 0.20450153946876526 epoch 15 batch 70 
total loss: 0.22405973076820374 epoch 15 batch 75 
total loss: 0.21495337784290314 epoch 15 batch 80 
total loss: 0.19726605713367462 epoch 15 batch 85 
total loss: 0.1876198798418045 epoch 15 batch 90 
total loss: 0.19518446922302246 epoch 15 batch 95 
total loss: 0.2388114184141159 epoch 15 batch 100 
total loss: 0.19776995480060577 epoch 15 batch 105 
total loss: 0.20737934112548828 epoch 15 batch 110 
total loss: 0.20728227496147156 epoch 15 batch 115 
total loss: 0.19053156673908234 epoch 15 batch 120 
total loss: 0.17176872491836548 epoch 15 batch 125 

Evaluation


In [0]:
#In this section we evaluate our model on a raw_input converted to german, for this the entire sentence has to be passed
#through the length of the model, for this we use greedsampler to run through the decoder
#and the final embedding matrix trained on the data is used to generate embeddings
input_raw='how are you'

# We have a transcript file containing English-German pairs
# Preprocess X
input_lines = ['<start> '+input_raw+'']
input_sequences = [[en_tokenizer.word_index[w] for w in line.split(' ')] for line in input_lines]
input_sequences = tf.keras.preprocessing.sequence.pad_sequences(input_sequences,
                                                                maxlen=Tx, padding='post')
inp = tf.convert_to_tensor(input_sequences)
#print(inp.shape)
inference_batch_size = input_sequences.shape[0]
encoder_initial_cell_state = [tf.zeros((inference_batch_size, rnn_units)),
                              tf.zeros((inference_batch_size, rnn_units))]
encoder_emb_inp = encoderNetwork.encoder_embedding(inp)
a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp,
                                                initial_state =encoder_initial_cell_state)
print('a_tx :',a_tx.shape)
print('c_tx :', c_tx.shape)

start_tokens = tf.fill([inference_batch_size],ge_tokenizer.word_index['<start>'])

end_token = ge_tokenizer.word_index['<end>']

greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()

decoder_input = tf.expand_dims([ge_tokenizer.word_index['<start>']]* inference_batch_size,1)
decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)

decoder_instance = tfa.seq2seq.BasicDecoder(cell = decoderNetwork.rnn_cell, sampler = greedy_sampler,
                                            output_layer=decoderNetwork.dense_layer)
decoderNetwork.attention_mechanism.setup_memory(a)
#pass [ last step activations , encoder memory_state ] as input to decoder for LSTM
print("decoder_initial_state = [a_tx, c_tx] :",np.array([a_tx, c_tx]).shape)
decoder_initial_state = decoderNetwork.build_decoder_initial_state(inference_batch_size,
                                                                   encoder_state=[a_tx, c_tx],
                                                                   Dtype=tf.float32)
print("\nCompared to simple encoder-decoder without attention, the decoder_initial_state \
 is an AttentionWrapperState object containing s_prev tensors and context and alignment vector \n ")
print("decoder initial state shape :",np.array(decoder_initial_state).shape)
print("decoder_initial_state tensor \n", decoder_initial_state)

# Since we do not know the target sequence lengths in advance, we use maximum_iterations to limit the translation lengths.
# One heuristic is to decode up to two times the source sentence lengths.
maximum_iterations = tf.round(tf.reduce_max(Tx) * 2)

#initialize inference decoder
decoder_embedding_matrix = decoderNetwork.decoder_embedding.variables[0] 
(first_finished, first_inputs,first_state) = decoder_instance.initialize(decoder_embedding_matrix,
                             start_tokens = start_tokens,
                             end_token=end_token,
                             initial_state = decoder_initial_state)
#print( first_finished.shape)
print("\nfirst_inputs returns the same decoder_input i.e. embedding of  <start> :",first_inputs.shape)
print("start_index_emb_avg ", tf.reduce_sum(tf.reduce_mean(first_inputs, axis=0))) # mean along the batch

inputs = first_inputs
state = first_state  
predictions = np.empty((inference_batch_size,0), dtype = np.int32)                                                                             
for j in range(maximum_iterations):
    outputs, next_state, next_inputs, finished = decoder_instance.step(j,inputs,state)
    inputs = next_inputs
    state = next_state
    outputs = np.expand_dims(outputs.sample_id,axis = -1)
    predictions = np.append(predictions, outputs, axis = -1)


a_tx : (1, 1024)
c_tx : (1, 1024)
decoder_initial_state = [a_tx, c_tx] : (2, 1, 1024)

Compared to simple encoder-decoder without attention, the decoder_initial_state  is an AttentionWrapperState object containing s_prev tensors and context and alignment vector 
 
decoder initial state shape : (6,)
decoder_initial_state tensor 
 AttentionWrapperState(cell_state=[<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=
array([[ 0.00032512,  0.00170071,  0.06353987, ..., -0.01063914,
        -0.01768327, -0.08082021]], dtype=float32)>, <tf.Tensor: shape=(1, 1024), dtype=float32, numpy=
array([[ 0.00110459,  0.00727613,  0.17626816, ..., -0.02890627,
        -0.06944194, -0.15541168]], dtype=float32)>], attention=<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=array([[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, time=<tf.Tensor: shape=(), dtype=int32, numpy=0>, alignments=<tf.Tensor: shape=(1, 7), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>, alignment_history=(), attention_state=<tf.Tensor: shape=(1, 7), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>)

first_inputs returns the same decoder_input i.e. embedding of  <start> : (1, 256)
start_index_emb_avg  tf.Tensor(0.108049706, shape=(), dtype=float32)

Final Translation


In [0]:
#prediction based on our sentence earlier
print("English Sentence:")
print(input_raw)
print("\nGerman Translation:")
for i in range(len(predictions)):
    line = predictions[i,:]
    seq = list(itertools.takewhile( lambda index: index !=2, line))
    print(" ".join( [ge_tokenizer.index_word[w] for w in seq]))


English Sentence:
how are you

German Translation:
wie arrogant

The accuracy can be improved by implementing:

  • Beam Search or Lexicon Search
  • Bi-directional encoder-decoder model