Quora question pairs: training a model with attention

Import packages


In [1]:
%matplotlib inline
from __future__ import print_function
import numpy as np
import pandas as pd
import datetime, time, json
from keras.models import Model
from keras.layers import Input, Bidirectional, LSTM, dot, Flatten, Dense, Reshape, add, Dropout, BatchNormalization
from keras.layers.embeddings import Embedding
from keras.regularizers import l2
from keras.callbacks import Callback, ModelCheckpoint
from keras import backend as K
from sklearn.model_selection import train_test_split


Using TensorFlow backend.

Initialize global variables


In [8]:
Q1_TRAINING_DATA_FILE = 'q1_train.npy'
Q2_TRAINING_DATA_FILE = 'q2_train.npy'
LABEL_TRAINING_DATA_FILE = 'label_train.npy'
WORD_EMBEDDING_MATRIX_FILE = 'word_embedding_matrix.npy'
NB_WORDS_DATA_FILE = 'nb_words.json'
MODEL_WEIGHTS_FILE = 'question_pairs_weights.h5'
MAX_SEQUENCE_LENGTH = 25
WORD_EMBEDDING_DIM = 300
SENT_EMBEDDING_DIM = 128
VALIDATION_SPLIT = 0.1
TEST_SPLIT = 0.1
RNG_SEED = 13371447
NB_EPOCHS = 25
DROPOUT = 0.2
BATCH_SIZE = 516

Load the dataset, embedding matrix and word count


In [3]:
q1_data = np.load(open(Q1_TRAINING_DATA_FILE, 'rb'))
q2_data = np.load(open(Q2_TRAINING_DATA_FILE, 'rb'))
labels = np.load(open(LABEL_TRAINING_DATA_FILE, 'rb'))
word_embedding_matrix = np.load(open(WORD_EMBEDDING_MATRIX_FILE, 'rb'))
with open(NB_WORDS_DATA_FILE, 'r') as f:
    nb_words = json.load(f)['nb_words']
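
A small consistency check, optional and not in the original notebook: by construction the embedding matrix should have nb_words + 1 rows (in the usual Keras tokenizer convention, index 0 is reserved for padding) of dimension WORD_EMBEDDING_DIM, matching the Embedding layers defined below.


In [ ]:
assert word_embedding_matrix.shape == (nb_words + 1, WORD_EMBEDDING_DIM)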

Partition the dataset into train and test sets


In [4]:
X = np.stack((q1_data, q2_data), axis=1)
y = labels
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SPLIT, random_state=RNG_SEED)
Q1_train = X_train[:,0]
Q2_train = X_train[:,1]
Q1_test = X_test[:,0]
Q2_test = X_test[:,1]
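
As a quick sanity check, also an optional addition: np.stack with axis=1 packs the two padded question arrays side by side, so X has shape (num_pairs, 2, MAX_SEQUENCE_LENGTH) and each slice recovers one question per pair.


In [ ]:
# Optional shape check: each partition holds padded index sequences
# of length MAX_SEQUENCE_LENGTH, with one row per question pair.
print('X_train: ', X_train.shape)   # (n_train, 2, 25)
print('Q1_train:', Q1_train.shape)  # (n_train, 25)
print('Q1_test: ', Q1_test.shape)   # (n_test, 25)
print('y_train: ', y_train.shape)   # (n_train,)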

Define the model

Each question is mapped through a frozen copy of the pre-trained embedding matrix loaded above, then encoded by its own bidirectional LSTM whose forward and backward outputs are summed, so every timestep becomes a 128-dimensional vector. The attention block contracts the two encoded sequences over the time axis into a 128×128 interaction matrix, flattens it, and learns a projection back to the shape of an encoded question; this attended representation is added to question 1's encoding before four Dense/Dropout/BatchNormalization blocks produce the duplicate probability.


In [9]:
question1 = Input(shape=(MAX_SEQUENCE_LENGTH,))
question2 = Input(shape=(MAX_SEQUENCE_LENGTH,))

q1 = Embedding(nb_words + 1, 
                 WORD_EMBEDDING_DIM, 
                 weights=[word_embedding_matrix], 
                 input_length=MAX_SEQUENCE_LENGTH, 
                 trainable=False)(question1)
q1 = Bidirectional(LSTM(SENT_EMBEDDING_DIM, return_sequences=True), merge_mode="sum")(q1)

q2 = Embedding(nb_words + 1, 
                 WORD_EMBEDDING_DIM, 
                 weights=[word_embedding_matrix], 
                 input_length=MAX_SEQUENCE_LENGTH, 
                 trainable=False)(question2)
q2 = Bidirectional(LSTM(SENT_EMBEDDING_DIM, return_sequences=True), merge_mode="sum")(q2)

# "Attention": contract the two encoded sequences over the time axis into a
# 128 x 128 interaction matrix, then learn a projection of it back to the
# shape of an encoded question.
attention = dot([q1, q2], axes=[1, 1])
attention = Flatten()(attention)
attention = Dense(MAX_SEQUENCE_LENGTH * SENT_EMBEDDING_DIM)(attention)
attention = Reshape((MAX_SEQUENCE_LENGTH, SENT_EMBEDDING_DIM))(attention)

# Fuse: add the attended representation of question 2 to question 1's encoding.
merged = add([q1, attention])
merged = Flatten()(merged)
merged = Dense(200, activation='relu')(merged)
merged = Dropout(DROPOUT)(merged)
merged = BatchNormalization()(merged)
merged = Dense(200, activation='relu')(merged)
merged = Dropout(DROPOUT)(merged)
merged = BatchNormalization()(merged)
merged = Dense(200, activation='relu')(merged)
merged = Dropout(DROPOUT)(merged)
merged = BatchNormalization()(merged)
merged = Dense(200, activation='relu')(merged)
merged = Dropout(DROPOUT)(merged)
merged = BatchNormalization()(merged)

is_duplicate = Dense(1, activation='sigmoid')(merged)

model = Model(inputs=[question1,question2], outputs=is_duplicate)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
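
To make the attention block's shape arithmetic concrete, here is a small NumPy sketch, illustrative only: dotting two (25, 128) encodings over the time axis yields a (128, 128) interaction matrix per example, which is flattened to 16,384 features and projected to 25 × 128 = 3,200 before being reshaped back to the encoder's output shape. Note in the summary below that this single projection accounts for 52,432,000 (16384 × 3200 + 3200) of the trainable parameters.


In [ ]:
# Illustrative shape check for the attention block (batch of 2).
a = np.zeros((2, 25, 128))               # encoded question 1
b = np.zeros((2, 25, 128))               # encoded question 2
inter = np.einsum('btf,btg->bfg', a, b)  # dot over time axis: (2, 128, 128)
flat = inter.reshape(2, -1)              # (2, 16384)
proj = np.zeros((2, 25 * 128))           # stand-in for the Dense projection
att = proj.reshape(2, 25, 128)           # back to (2, 25, 128)
print(inter.shape, flat.shape, att.shape)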

In [10]:
model.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_3 (InputLayer)             (None, 25)            0                                            
____________________________________________________________________________________________________
input_4 (InputLayer)             (None, 25)            0                                            
____________________________________________________________________________________________________
embedding_3 (Embedding)          (None, 25, 300)       28679100    input_3[0][0]                    
____________________________________________________________________________________________________
embedding_4 (Embedding)          (None, 25, 300)       28679100    input_4[0][0]                    
____________________________________________________________________________________________________
bidirectional_3 (Bidirectional)  (None, 25, 128)       439296      embedding_3[0][0]                
____________________________________________________________________________________________________
bidirectional_4 (Bidirectional)  (None, 25, 128)       439296      embedding_4[0][0]                
____________________________________________________________________________________________________
dot_2 (Dot)                      (None, 128, 128)      0           bidirectional_3[0][0]            
                                                                   bidirectional_4[0][0]            
____________________________________________________________________________________________________
flatten_3 (Flatten)              (None, 16384)         0           dot_2[0][0]                      
____________________________________________________________________________________________________
dense_7 (Dense)                  (None, 3200)          52432000    flatten_3[0][0]                  
____________________________________________________________________________________________________
reshape_2 (Reshape)              (None, 25, 128)       0           dense_7[0][0]                    
____________________________________________________________________________________________________
add_2 (Add)                      (None, 25, 128)       0           bidirectional_3[0][0]            
                                                                   reshape_2[0][0]                  
____________________________________________________________________________________________________
flatten_4 (Flatten)              (None, 3200)          0           add_2[0][0]                      
____________________________________________________________________________________________________
dense_8 (Dense)                  (None, 200)           640200      flatten_4[0][0]                  
____________________________________________________________________________________________________
dropout_5 (Dropout)              (None, 200)           0           dense_8[0][0]                    
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 200)           800         dropout_5[0][0]                  
____________________________________________________________________________________________________
dense_9 (Dense)                  (None, 200)           40200       batch_normalization_5[0][0]      
____________________________________________________________________________________________________
dropout_6 (Dropout)              (None, 200)           0           dense_9[0][0]                    
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 200)           800         dropout_6[0][0]                  
____________________________________________________________________________________________________
dense_10 (Dense)                 (None, 200)           40200       batch_normalization_6[0][0]      
____________________________________________________________________________________________________
dropout_7 (Dropout)              (None, 200)           0           dense_10[0][0]                   
____________________________________________________________________________________________________
batch_normalization_7 (BatchNorm (None, 200)           800         dropout_7[0][0]                  
____________________________________________________________________________________________________
dense_11 (Dense)                 (None, 200)           40200       batch_normalization_7[0][0]      
____________________________________________________________________________________________________
dropout_8 (Dropout)              (None, 200)           0           dense_11[0][0]                   
____________________________________________________________________________________________________
batch_normalization_8 (BatchNorm (None, 200)           800         dropout_8[0][0]                  
____________________________________________________________________________________________________
dense_12 (Dense)                 (None, 1)             201         batch_normalization_8[0][0]      
====================================================================================================
Total params: 111,432,993
Trainable params: 54,073,193
Non-trainable params: 57,359,800
____________________________________________________________________________________________________

Train the model, checkpointing the weights with the best validation accuracy


In [11]:
print("Starting training at", datetime.datetime.now())
t0 = time.time()
callbacks = [ModelCheckpoint(MODEL_WEIGHTS_FILE, monitor='val_acc', save_best_only=True)]
history = model.fit([Q1_train, Q2_train],
                    y_train,
                    epochs=NB_EPOCHS,
                    validation_split=VALIDATION_SPLIT,
                    verbose=2,
                    batch_size=BATCH_SIZE,
                    callbacks=callbacks)
t1 = time.time()
print("Training ended at", datetime.datetime.now())
print("Minutes elapsed: %f" % ((t1 - t0) / 60.))


Starting training at 2017-06-01 23:44:48.287525
Train on 327474 samples, validate on 36387 samples
Epoch 1/25
242s - loss: 0.5341 - acc: 0.7256 - val_loss: 0.4826 - val_acc: 0.7525
Epoch 2/25
239s - loss: 0.4395 - acc: 0.7847 - val_loss: 0.4441 - val_acc: 0.7805
Epoch 3/25
239s - loss: 0.3921 - acc: 0.8124 - val_loss: 0.4037 - val_acc: 0.8045
Epoch 4/25
239s - loss: 0.3501 - acc: 0.8375 - val_loss: 0.4056 - val_acc: 0.8121
Epoch 5/25
238s - loss: 0.3093 - acc: 0.8595 - val_loss: 0.4167 - val_acc: 0.8150
Epoch 6/25
239s - loss: 0.2681 - acc: 0.8814 - val_loss: 0.4222 - val_acc: 0.8171
Epoch 7/25
239s - loss: 0.2270 - acc: 0.9023 - val_loss: 0.4548 - val_acc: 0.8181
Epoch 8/25
239s - loss: 0.1894 - acc: 0.9203 - val_loss: 0.4795 - val_acc: 0.8230
Epoch 9/25
239s - loss: 0.1555 - acc: 0.9365 - val_loss: 0.5398 - val_acc: 0.8231
Epoch 10/25
229s - loss: 0.1291 - acc: 0.9486 - val_loss: 0.5389 - val_acc: 0.8200
Epoch 11/25
229s - loss: 0.1077 - acc: 0.9575 - val_loss: 0.6279 - val_acc: 0.8185
Epoch 12/25
229s - loss: 0.0925 - acc: 0.9639 - val_loss: 0.6428 - val_acc: 0.8208
Epoch 13/25
229s - loss: 0.0791 - acc: 0.9697 - val_loss: 0.6564 - val_acc: 0.8225
Epoch 14/25
229s - loss: 0.0677 - acc: 0.9741 - val_loss: 0.6822 - val_acc: 0.8222
Epoch 15/25
229s - loss: 0.0605 - acc: 0.9770 - val_loss: 0.7541 - val_acc: 0.8209
Epoch 16/25
238s - loss: 0.0541 - acc: 0.9798 - val_loss: 0.7523 - val_acc: 0.8241
Epoch 17/25
228s - loss: 0.0485 - acc: 0.9819 - val_loss: 0.7647 - val_acc: 0.8232
Epoch 18/25
229s - loss: 0.0453 - acc: 0.9835 - val_loss: 0.8060 - val_acc: 0.8239
Epoch 19/25
229s - loss: 0.0436 - acc: 0.9843 - val_loss: 0.7953 - val_acc: 0.8223
Epoch 20/25
238s - loss: 0.0390 - acc: 0.9858 - val_loss: 0.8509 - val_acc: 0.8248
Epoch 21/25
238s - loss: 0.0371 - acc: 0.9865 - val_loss: 0.7879 - val_acc: 0.8263
Epoch 22/25
229s - loss: 0.0349 - acc: 0.9874 - val_loss: 0.8984 - val_acc: 0.8212
Epoch 23/25
229s - loss: 0.0330 - acc: 0.9883 - val_loss: 0.8845 - val_acc: 0.8186
Epoch 24/25
229s - loss: 0.0307 - acc: 0.9891 - val_loss: 0.8727 - val_acc: 0.8263
Epoch 25/25
229s - loss: 0.0291 - acc: 0.9896 - val_loss: 0.8777 - val_acc: 0.8220
Training ended at 2017-06-02 01:22:23.716517
Minutes elapsed: 97.590475
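
The log shows a familiar overfitting pattern: training loss keeps falling while validation loss bottoms out around epoch 3 and then climbs, so the checkpoint callback is what preserves the most useful weights. An optional variation, not used in this run, is to also stop training once validation accuracy stalls; a minimal sketch with Keras's EarlyStopping callback:


In [ ]:
from keras.callbacks import EarlyStopping

# Hypothetical alternative callback list: keep checkpointing as before,
# but halt after 5 epochs without improvement in validation accuracy.
callbacks = [ModelCheckpoint(MODEL_WEIGHTS_FILE, monitor='val_acc', save_best_only=True),
             EarlyStopping(monitor='val_acc', patience=5)]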

Plot training and validation accuracy


In [12]:
acc = pd.DataFrame({'epoch': [i + 1 for i in history.epoch],
                    'training': history.history['acc'],
                    'validation': history.history['val_acc']})
ax = acc.plot(x='epoch', figsize=(8, 5), grid=True)  # figsize must be a tuple, not a set
ax.set_ylabel("accuracy")
ax.set_ylim([0.0, 1.0]);


[plot: training and validation accuracy by epoch]

In [13]:
max_val_acc, idx = max((val, idx) for (idx, val) in enumerate(history.history['val_acc']))
print('Maximum validation accuracy at epoch', '{:d}'.format(idx+1), '=', '{:.4f}'.format(max_val_acc))


Maximum validation accuracy at epoch 21 = 0.8263

Evaluate the model with the best validation accuracy on the test partition


In [14]:
model.load_weights(MODEL_WEIGHTS_FILE)
loss, accuracy = model.evaluate([Q1_test, Q2_test], y_test, verbose=0)
print('loss = {0:.4f}, accuracy = {1:.4f}'.format(loss, accuracy))


loss = 0.7956, accuracy = 0.8243
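
As a final optional step, not part of the original run, the reloaded model can score individual pairs; a minimal sketch using the already-padded test arrays (so no extra preprocessing is needed):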

In [ ]:
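# Score a few held-out pairs with the best checkpointed weights.
probs = model.predict([Q1_test[:5], Q2_test[:5]])
for p, label in zip(probs.ravel(), y_test[:5]):
    print('P(duplicate) = {0:.3f}  actual = {1}'.format(p, label))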