Using pre-trained word embeddings in a Keras model

Based on https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html


In [23]:
from __future__ import print_function


import os
import sys
import numpy as np
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.layers import Dense, Input, GlobalMaxPooling1D
from keras.layers import Conv1D, MaxPooling1D, Embedding, Flatten
from keras.models import Model

In [11]:
from sklearn.datasets import fetch_20newsgroups

data_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'),
                                shuffle=True, random_state=42)

data_test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'),
                               shuffle=True, random_state=42)
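
A quick sanity check on the two splits before any preprocessing (the counts in the comments are the standard 20 Newsgroups sizes):

# Inspect the raw splits: 20 Newsgroups ships 11314 training and 7532 test documents
print('%d training documents' % len(data_train.data))
print('%d test documents' % len(data_test.data))
print('%d categories' % len(data_train.target_names))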

In [12]:
texts = data_train.data
labels = data_train.target
# Map each numeric label id to its newsgroup name
labels_index = {}
for i, l in enumerate(data_train.target_names):
    labels_index[i] = l
    
labels_index


Out[12]:
{0: 'alt.atheism',
 1: 'comp.graphics',
 2: 'comp.os.ms-windows.misc',
 3: 'comp.sys.ibm.pc.hardware',
 4: 'comp.sys.mac.hardware',
 5: 'comp.windows.x',
 6: 'misc.forsale',
 7: 'rec.autos',
 8: 'rec.motorcycles',
 9: 'rec.sport.baseball',
 10: 'rec.sport.hockey',
 11: 'sci.crypt',
 12: 'sci.electronics',
 13: 'sci.med',
 14: 'sci.space',
 15: 'soc.religion.christian',
 16: 'talk.politics.guns',
 17: 'talk.politics.mideast',
 18: 'talk.politics.misc',
 19: 'talk.religion.misc'}

In [13]:
data_train.data[0]


Out[13]:
'I was wondering if anyone out there could enlighten me on this car I saw\nthe other day. It was a 2-door sports car, looked to be from the late 60s/\nearly 70s. It was called a Bricklin. The doors were really small. In addition,\nthe front bumper was separate from the rest of the body. This is \nall I know. If anyone can tellme a model name, engine specs, years\nof production, where this car is made, history, or whatever info you\nhave on this funky looking car, please e-mail.'

In [17]:
MAX_SEQUENCE_LENGTH = 1000
MAX_NB_WORDS = 20000
EMBEDDING_DIM = 100
VALIDATION_SPLIT = 0.2


tokenizer = Tokenizer(num_words=MAX_NB_WORDS)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)

word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))

data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)

labels = to_categorical(np.asarray(labels))
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)

# split the data into a training set and a validation set
indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]
nb_validation_samples = int(VALIDATION_SPLIT * data.shape[0])

x_train = data[:-nb_validation_samples]
y_train = labels[:-nb_validation_samples]
x_val = data[-nb_validation_samples:]
y_val = labels[-nb_validation_samples:]


Found 105372 unique tokens.
Shape of data tensor: (11314, 1000)
Shape of label tensor: (11314, 20)
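
The same fitted tokenizer has to be reused for any text seen later (validation, test, or new documents); a minimal sketch with a made-up sentence:

# New text must go through the same tokenizer + padding pipeline
new_texts = ['The new graphics card renders 3D scenes very quickly.']
new_sequences = tokenizer.texts_to_sequences(new_texts)
new_data = pad_sequences(new_sequences, maxlen=MAX_SEQUENCE_LENGTH)
print(new_data.shape)  # (1, 1000), zero-padded on the left by default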

Preparing the Embedding layer
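
The next cell assumes the GloVe 6B vectors are already unzipped under DATA_DIR. If they are not, a one-time download along these lines should work (the URL is assumed to be the standard Stanford NLP one; adjust the path to your setup):

# One-time download of the GloVe 6B vectors (URL and target path are assumptions)
import zipfile
from urllib.request import urlretrieve

glove_dir = '/home/jorge/data/text/glove.6B'   # matches DATA_DIR below
if not os.path.exists(os.path.join(glove_dir, 'glove.6B.100d.txt')):
    os.makedirs(glove_dir, exist_ok=True)
    zip_path, _ = urlretrieve('http://nlp.stanford.edu/data/glove.6B.zip')
    with zipfile.ZipFile(zip_path) as z:
        z.extractall(glove_dir)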


In [18]:
DATA_DIR = '/home/jorge/data/text'

embeddings_index = {}
# Parse the GloVe file: each line is a word followed by its 100-d vector
with open(os.path.join(DATA_DIR, 'glove.6B/glove.6B.100d.txt'), encoding='utf-8') as f:
    for line in f:
        values = line.split()
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs

print('Found %s word vectors.' % len(embeddings_index))


Found 400000 word vectors.
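
Each entry of embeddings_index is a 100-dimensional float32 vector; a quick sanity check is to compare a few related and unrelated words with plain NumPy cosine similarity (a minimal sketch; the word choices are arbitrary):

# Cosine similarity between GloVe vectors as a sanity check
def cosine(u, v):
    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

print(embeddings_index['car'].shape)                                   # (100,)
print(cosine(embeddings_index['car'], embeddings_index['truck']))      # should be fairly high
print(cosine(embeddings_index['car'], embeddings_index['religion']))   # should be much lower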

In [20]:
embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector
    # words not found in the embeddings index are left as all-zero rows
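
Since words missing from GloVe keep an all-zero row, it is worth checking how much of the tokenizer vocabulary is actually covered (a small sketch built on the variables above):

# Fraction of the tokenizer vocabulary that received a pre-trained vector
covered = sum(1 for word in word_index if word in embeddings_index)
print('%d / %d tokens covered by GloVe (%.1f%%)'
      % (covered, len(word_index), 100.0 * covered / len(word_index)))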

In [21]:
from keras.layers import Embedding

embedding_layer = Embedding(len(word_index) + 1,
                            EMBEDDING_DIM,
                            weights=[embedding_matrix],
                            input_length=MAX_SEQUENCE_LENGTH,
                            trainable=False)
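
trainable=False freezes the GloVe weights, so only the convolutional and dense layers learn below. A variant worth comparing is to let the embeddings be fine-tuned as well (not used in the run below):

# Variant: same pre-trained initialisation, but the embedding weights are updated during training
embedding_layer_finetuned = Embedding(len(word_index) + 1,
                                      EMBEDDING_DIM,
                                      weights=[embedding_matrix],
                                      input_length=MAX_SEQUENCE_LENGTH,
                                      trainable=True)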

Training a 1D convnet


In [34]:
from keras.optimizers import SGD


sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu')(embedded_sequences)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu')(x)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu')(x)
x = MaxPooling1D(35)(x)  # global max pooling
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
preds = Dense(len(labels_index), activation='softmax')(x)

model = Model(sequence_input, preds)
model.summary()

sgd_optimizer = SGD(lr=0.01, momentum=0.99, decay=0.001, nesterov=True)
model.compile(loss='categorical_crossentropy',
              optimizer=sgd_optimizer,
              metrics=['acc'])

# happy learning!
model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=50, batch_size=128)


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        (None, 1000)              0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 1000, 100)         10537300  
_________________________________________________________________
conv1d_30 (Conv1D)           (None, 996, 128)          64128     
_________________________________________________________________
max_pooling1d_30 (MaxPooling (None, 199, 128)          0         
_________________________________________________________________
conv1d_31 (Conv1D)           (None, 195, 128)          82048     
_________________________________________________________________
max_pooling1d_31 (MaxPooling (None, 39, 128)           0         
_________________________________________________________________
conv1d_32 (Conv1D)           (None, 35, 128)           82048     
_________________________________________________________________
max_pooling1d_32 (MaxPooling (None, 1, 128)            0         
_________________________________________________________________
flatten_10 (Flatten)         (None, 128)               0         
_________________________________________________________________
dense_19 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_20 (Dense)             (None, 20)                2580      
=================================================================
Total params: 10,784,616
Trainable params: 247,316
Non-trainable params: 10,537,300
_________________________________________________________________
Train on 9052 samples, validate on 2262 samples
Epoch 1/50
9052/9052 [==============================] - 1s - loss: 2.9026 - acc: 0.0949 - val_loss: 2.5609 - val_acc: 0.1667
Epoch 2/50
9052/9052 [==============================] - 1s - loss: 2.3365 - acc: 0.2154 - val_loss: 2.0781 - val_acc: 0.2754
Epoch 3/50
9052/9052 [==============================] - 1s - loss: 1.8558 - acc: 0.3452 - val_loss: 1.8000 - val_acc: 0.3660
Epoch 4/50
9052/9052 [==============================] - 1s - loss: 1.5351 - acc: 0.4651 - val_loss: 1.5959 - val_acc: 0.4664
Epoch 5/50
9052/9052 [==============================] - 1s - loss: 1.3099 - acc: 0.5461 - val_loss: 1.4959 - val_acc: 0.5093
Epoch 6/50
9052/9052 [==============================] - 1s - loss: 1.1301 - acc: 0.6075 - val_loss: 1.4695 - val_acc: 0.5332
Epoch 7/50
9052/9052 [==============================] - 1s - loss: 0.9741 - acc: 0.6607 - val_loss: 1.5017 - val_acc: 0.5442
Epoch 8/50
9052/9052 [==============================] - 1s - loss: 0.8302 - acc: 0.7077 - val_loss: 1.6032 - val_acc: 0.5526
Epoch 9/50
9052/9052 [==============================] - 1s - loss: 0.6903 - acc: 0.7586 - val_loss: 1.6759 - val_acc: 0.5592
Epoch 10/50
9052/9052 [==============================] - 1s - loss: 0.5701 - acc: 0.7990 - val_loss: 1.8551 - val_acc: 0.5597
Epoch 11/50
9052/9052 [==============================] - 1s - loss: 0.4754 - acc: 0.8312 - val_loss: 2.1731 - val_acc: 0.5539
Epoch 12/50
9052/9052 [==============================] - 1s - loss: 0.4150 - acc: 0.8541 - val_loss: 2.3961 - val_acc: 0.5615
Epoch 13/50
9052/9052 [==============================] - 1s - loss: 0.3623 - acc: 0.8814 - val_loss: 2.6773 - val_acc: 0.5588
Epoch 14/50
9052/9052 [==============================] - 1s - loss: 0.3372 - acc: 0.8900 - val_loss: 2.7061 - val_acc: 0.5623
Epoch 15/50
9052/9052 [==============================] - 1s - loss: 0.3177 - acc: 0.8956 - val_loss: 2.9758 - val_acc: 0.5513
Epoch 16/50
9052/9052 [==============================] - 1s - loss: 0.3358 - acc: 0.8883 - val_loss: 2.8447 - val_acc: 0.5504
Epoch 17/50
9052/9052 [==============================] - 1s - loss: 0.3498 - acc: 0.8886 - val_loss: 2.8347 - val_acc: 0.5544
Epoch 18/50
9052/9052 [==============================] - 1s - loss: 0.3396 - acc: 0.8911 - val_loss: 2.8171 - val_acc: 0.5513
Epoch 19/50
9052/9052 [==============================] - 1s - loss: 0.2860 - acc: 0.9102 - val_loss: 3.0459 - val_acc: 0.5539
Epoch 20/50
9052/9052 [==============================] - 1s - loss: 0.2638 - acc: 0.9143 - val_loss: 3.0292 - val_acc: 0.5522
Epoch 21/50
9052/9052 [==============================] - 1s - loss: 0.2467 - acc: 0.9223 - val_loss: 3.3793 - val_acc: 0.5517
Epoch 22/50
9052/9052 [==============================] - 1s - loss: 0.2407 - acc: 0.9222 - val_loss: 3.3935 - val_acc: 0.5579
Epoch 23/50
9052/9052 [==============================] - 1s - loss: 0.2421 - acc: 0.9222 - val_loss: 3.4062 - val_acc: 0.5535
Epoch 24/50
9052/9052 [==============================] - 1s - loss: 0.2362 - acc: 0.9233 - val_loss: 3.3569 - val_acc: 0.5447
Epoch 25/50
9052/9052 [==============================] - 1s - loss: 0.2225 - acc: 0.9319 - val_loss: 3.3845 - val_acc: 0.5681
Epoch 26/50
9052/9052 [==============================] - 1s - loss: 0.2136 - acc: 0.9322 - val_loss: 3.5193 - val_acc: 0.5654
Epoch 27/50
9052/9052 [==============================] - 1s - loss: 0.1983 - acc: 0.9371 - val_loss: 3.5416 - val_acc: 0.5522
Epoch 28/50
9052/9052 [==============================] - 1s - loss: 0.1846 - acc: 0.9419 - val_loss: 3.3554 - val_acc: 0.5650
Epoch 29/50
9052/9052 [==============================] - 1s - loss: 0.1745 - acc: 0.9474 - val_loss: 3.4919 - val_acc: 0.5707
Epoch 30/50
9052/9052 [==============================] - 1s - loss: 0.1530 - acc: 0.9518 - val_loss: 3.6556 - val_acc: 0.5663
Epoch 31/50
9052/9052 [==============================] - 1s - loss: 0.1479 - acc: 0.9535 - val_loss: 3.8163 - val_acc: 0.5601
Epoch 32/50
9052/9052 [==============================] - 1s - loss: 0.1449 - acc: 0.9534 - val_loss: 3.7456 - val_acc: 0.5694
Epoch 33/50
9052/9052 [==============================] - 1s - loss: 0.1364 - acc: 0.9570 - val_loss: 3.8401 - val_acc: 0.5663
Epoch 34/50
9052/9052 [==============================] - 1s - loss: 0.1183 - acc: 0.9643 - val_loss: 4.0214 - val_acc: 0.5619
Epoch 35/50
9052/9052 [==============================] - 1s - loss: 0.1239 - acc: 0.9611 - val_loss: 4.0244 - val_acc: 0.5645
Epoch 36/50
9052/9052 [==============================] - 1s - loss: 0.1255 - acc: 0.9618 - val_loss: 3.9801 - val_acc: 0.5619
Epoch 37/50
9052/9052 [==============================] - 1s - loss: 0.1155 - acc: 0.9638 - val_loss: 4.1707 - val_acc: 0.5690
Epoch 38/50
9052/9052 [==============================] - 1s - loss: 0.1112 - acc: 0.9656 - val_loss: 4.1792 - val_acc: 0.5650
Epoch 39/50
9052/9052 [==============================] - 1s - loss: 0.1135 - acc: 0.9644 - val_loss: 4.3631 - val_acc: 0.5592
Epoch 40/50
9052/9052 [==============================] - 1s - loss: 0.1187 - acc: 0.9637 - val_loss: 4.4191 - val_acc: 0.5637
Epoch 41/50
9052/9052 [==============================] - 1s - loss: 0.1188 - acc: 0.9631 - val_loss: 4.2304 - val_acc: 0.5729
Epoch 42/50
9052/9052 [==============================] - 1s - loss: 0.1344 - acc: 0.9599 - val_loss: 4.1120 - val_acc: 0.5707
Epoch 43/50
9052/9052 [==============================] - 1s - loss: 0.1258 - acc: 0.9637 - val_loss: 4.0578 - val_acc: 0.5659
Epoch 44/50
9052/9052 [==============================] - 1s - loss: 0.1142 - acc: 0.9634 - val_loss: 4.1595 - val_acc: 0.5672
Epoch 45/50
9052/9052 [==============================] - 1s - loss: 0.1151 - acc: 0.9620 - val_loss: 4.0692 - val_acc: 0.5681
Epoch 46/50
9052/9052 [==============================] - 1s - loss: 0.1137 - acc: 0.9660 - val_loss: 4.3480 - val_acc: 0.5712
Epoch 47/50
9052/9052 [==============================] - 1s - loss: 0.1182 - acc: 0.9643 - val_loss: 4.2420 - val_acc: 0.5712
Epoch 48/50
9052/9052 [==============================] - 1s - loss: 0.1099 - acc: 0.9661 - val_loss: 4.2597 - val_acc: 0.5796
Epoch 49/50
9052/9052 [==============================] - 1s - loss: 0.1044 - acc: 0.9674 - val_loss: 4.2330 - val_acc: 0.5756
Epoch 50/50
9052/9052 [==============================] - 1s - loss: 0.0964 - acc: 0.9704 - val_loss: 4.2597 - val_acc: 0.5769
Out[34]:
<keras.callbacks.History at 0x7f5577677dd8>
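
Validation accuracy plateaus around 0.55-0.58 while training accuracy keeps climbing, so the model overfits fairly quickly. A natural next step is to score the held-out 20 Newsgroups test split with the same preprocessing (a sketch reusing the fitted tokenizer; not executed above):

# Evaluate on the 20 Newsgroups test split with the same tokenizer and padding
test_sequences = tokenizer.texts_to_sequences(data_test.data)
x_test = pad_sequences(test_sequences, maxlen=MAX_SEQUENCE_LENGTH)
y_test = to_categorical(np.asarray(data_test.target))
test_loss, test_acc = model.evaluate(x_test, y_test, batch_size=128)
print('Test accuracy: %.4f' % test_acc)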
