In this tutorial, we will learn how to apply Doc2vec using gensim by recreating the results of Le and Mikolov 2014.
Early state-of-the-art document representations were based on the bag-of-words model, which represents each input document as a fixed-length vector. For example, borrowing from the Wikipedia article, the two documents
(1) John likes to watch movies. Mary likes movies too.
(2) John also likes to watch football games.
are used to construct a length 10 list of words
["John", "likes", "to", "watch", "movies", "Mary", "too", "also", "football", "games"]
so that we can then represent each document as a fixed-length vector whose elements are the frequencies of the corresponding words in our list
(1) [1, 2, 1, 1, 2, 1, 1, 0, 0, 0]
(2) [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]
Bag-of-words models are surprisingly effective, but they lose all information about word order. Bag-of-n-grams models, which count word phrases of length n instead of single words, capture some local word order while still producing fixed-length vectors, but they suffer from data sparsity and high dimensionality.
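To make the counting concrete, here is a minimal sketch (not part of the original experiment) that reproduces the two vectors above in plain Python; the document strings and the 10-word list are just the toy examples from this section.
docs = [
    "John likes to watch movies. Mary likes movies too.",
    "John also likes to watch football games.",
]
vocab = ["John", "likes", "to", "watch", "movies", "Mary", "too", "also", "football", "games"]

def bow_vector(doc, vocab):
    # Count how often each vocabulary word occurs in the document
    tokens = doc.replace(".", "").split()
    return [tokens.count(term) for term in vocab]

for doc in docs:
    print(bow_vector(doc, vocab))
# -> [1, 2, 1, 1, 2, 1, 1, 0, 0, 0]
# -> [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]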
Word2Vec
Word2Vec is a more recent model that embeds words in a lower-dimensional vector space using a shallow neural network. The result is a set of word-vectors where vectors close together in vector space have similar meanings based on context, and word-vectors distant from each other have differing meanings. For example, strong and powerful would be close together, while strong and Paris would be relatively far apart. There are two versions of this model, based on skip-grams (SG) and continuous-bag-of-words (CBOW), both implemented by the gensim Word2Vec class.
Word2Vec - Skip-gram Model
The skip-gram Word2Vec model takes in pairs (word1, word2) generated by moving a window across the text data, and trains a 1-hidden-layer neural network on the synthetic task: given an input word, predict a probability distribution over the words near it. A virtual one-hot encoding of the input word passes through a 'projection layer' to the hidden layer; these projection weights are later interpreted as the word embeddings. So if the hidden layer has 300 neurons, the network yields 300-dimensional word embeddings.
Word2Vec - Continuous-bag-of-words Model
Continuous-bag-of-words Word2Vec is very similar to the skip-gram model. It is also a 1-hidden-layer neural network, but the synthetic training task now uses the average of multiple input context words, rather than a single word as in skip-gram, to predict the center word. Again, the projection weights that turn one-hot words into averageable vectors, of the same width as the hidden layer, are interpreted as the word embeddings.
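As an aside, here is a hedged sketch of training a skip-gram Word2Vec model in gensim on a tiny made-up corpus; the corpus and parameter values are illustrative only (and only arguments whose names are stable across gensim versions are passed), not part of the experiment below.
from gensim.models import Word2Vec

toy_corpus = [
    ["strong", "powerful", "leader"],
    ["strong", "powerful", "army"],
    ["paris", "france", "capital"],
    ["paris", "eiffel", "tower"],
]
# sg=1 selects the skip-gram architecture; sg=0 would select CBOW
w2v = Word2Vec(toy_corpus, sg=1, window=2, min_count=1, seed=1)
print(w2v.wv.most_similar("strong", topn=2))  # neighbors are meaningless on a corpus this small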
But, Word2Vec doesn't yet get us fixed-size vectors for longer texts.
Doc2Vec
The straightforward approach of averaging each of a text's words' word-vectors creates a quick and crude document-vector that can often be useful. However, Le and Mikolov in 2014 introduced the Paragraph Vector, which usually outperforms such simple-averaging.
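As a point of comparison, here is a hedged sketch of that simple-averaging baseline, assuming w2v is any trained Word2Vec model (for instance the toy one sketched above); out-of-vocabulary words are simply skipped.
import numpy as np

def average_doc_vector(words, keyed_vectors):
    # Crude document vector: the mean of the word-vectors of the in-vocabulary words
    vecs = [keyed_vectors[w] for w in words if w in keyed_vectors]
    return np.mean(vecs, axis=0) if vecs else np.zeros(keyed_vectors.vector_size)

doc_vec = average_doc_vector(["strong", "powerful", "army"], w2v.wv)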
The basic idea is: act as if a document has another floating word-like vector, which contributes to all training predictions and is updated like other word-vectors, but which we will call a doc-vector. Gensim's Doc2Vec class implements this algorithm.
Paragraph Vector - Distributed Memory (PV-DM)
This is the Paragraph Vector model analogous to Word2Vec CBOW. The doc-vectors are obtained by training a neural network on the synthetic task of predicting a center word based on an average of both context word-vectors and the full document's doc-vector.
Paragraph Vector - Distributed Bag of Words (PV-DBOW)
This is the Paragraph Vector model analogous to Word2Vec SG. The doc-vectors are obtained by training a neural network on the synthetic task of predicting a target word just from the full document's doc-vector. (It is also common to combine this with skip-gram training, using both the doc-vector and nearby word-vectors to predict a single target word, but only one word at a time.)
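Before the full IMDB experiment, here is a minimal, hedged sketch of the gensim Doc2Vec API on a throwaway corpus; the documents, tags, and parameters are purely illustrative.
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

toy_docs = [
    TaggedDocument(words=["john", "likes", "movies"], tags=[0]),
    TaggedDocument(words=["john", "likes", "football"], tags=[1]),
]
# dm=0 selects PV-DBOW; dm=1 would select PV-DM
toy_model = Doc2Vec(toy_docs, dm=0, min_count=1, epochs=10)
new_vec = toy_model.infer_vector(["mary", "likes", "movies"])  # fixed-size vector for unseen text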
The following python modules are dependencies for this tutorial:
- testfixtures (pip install testfixtures)
- statsmodels (pip install statsmodels)
Let's download the IMDB archive if it is not already downloaded (84 MB). This will be our text data for this tutorial.
The data can be found here: http://ai.stanford.edu/~amaas/data/sentiment/
This cell will only reattempt steps (such as downloading the compressed data) if their output isn't already present, so it is safe to re-run until it completes successfully.
In [1]:
%%time
import locale
import glob
import os.path
import requests
import tarfile
import sys
import codecs
from smart_open import smart_open
import re

dirname = 'aclImdb'
filename = 'aclImdb_v1.tar.gz'
locale.setlocale(locale.LC_ALL, 'C')
all_lines = []

if sys.version > '3':
    control_chars = [chr(0x85)]
else:
    control_chars = [unichr(0x85)]

# Convert text to lower-case and strip punctuation/symbols from words
def normalize_text(text):
    norm_text = text.lower()
    # Replace breaks with spaces
    norm_text = norm_text.replace('<br />', ' ')
    # Pad punctuation with spaces on both sides
    norm_text = re.sub(r"([\.\",\(\)!\?;:])", " \\1 ", norm_text)
    return norm_text

if not os.path.isfile('aclImdb/alldata-id.txt'):
    if not os.path.isdir(dirname):
        if not os.path.isfile(filename):
            # Download IMDB archive
            print("Downloading IMDB archive...")
            url = u'http://ai.stanford.edu/~amaas/data/sentiment/' + filename
            r = requests.get(url)
            with smart_open(filename, 'wb') as f:
                f.write(r.content)
        # if error here, try `tar xfz aclImdb_v1.tar.gz` outside notebook, then re-run this cell
        tar = tarfile.open(filename, mode='r')
        tar.extractall()
        tar.close()
    else:
        print("IMDB archive directory already available without download.")

    # Collect & normalize test/train data
    print("Cleaning up dataset...")
    folders = ['train/pos', 'train/neg', 'test/pos', 'test/neg', 'train/unsup']
    for fol in folders:
        temp = u''
        newline = "\n".encode("utf-8")
        output = fol.replace('/', '-') + '.txt'
        # Is there a better pattern to use?
        txt_files = glob.glob(os.path.join(dirname, fol, '*.txt'))
        print(" %s: %i files" % (fol, len(txt_files)))
        with smart_open(os.path.join(dirname, output), "wb") as n:
            for i, txt in enumerate(txt_files):
                with smart_open(txt, "rb") as t:
                    one_text = t.read().decode("utf-8")
                    for c in control_chars:
                        one_text = one_text.replace(c, ' ')
                    one_text = normalize_text(one_text)
                    all_lines.append(one_text)
                n.write(one_text.encode("utf-8"))
                n.write(newline)

    # Save to disk for instant re-use on any future runs
    with smart_open(os.path.join(dirname, 'alldata-id.txt'), 'wb') as f:
        for idx, line in enumerate(all_lines):
            num_line = u"_*{0} {1}\n".format(idx, line)
            f.write(num_line.encode("utf-8"))

assert os.path.isfile("aclImdb/alldata-id.txt"), "alldata-id.txt unavailable"
print("Success, alldata-id.txt is available for next steps.")
The text data is small enough to be read into memory.
In [2]:
%%time
import gensim
from gensim.models.doc2vec import TaggedDocument
from collections import namedtuple

# this data object class suffices as a `TaggedDocument` (with `words` and `tags`)
# plus adds other state helpful for our later evaluation/reporting
SentimentDocument = namedtuple('SentimentDocument', 'words tags split sentiment')

alldocs = []
with smart_open('aclImdb/alldata-id.txt', 'rb', encoding='utf-8') as alldata:
    for line_no, line in enumerate(alldata):
        tokens = gensim.utils.to_unicode(line).split()
        words = tokens[1:]
        tags = [line_no]  # 'tags = [tokens[0]]' would also work at extra memory cost
        split = ['train', 'test', 'extra', 'extra'][line_no // 25000]  # 25k train, 25k test, 25k extra
        sentiment = [1.0, 0.0, 1.0, 0.0, None, None, None, None][line_no // 12500]  # [12.5K pos, 12.5K neg]*2 then unknown
        alldocs.append(SentimentDocument(words, tags, split, sentiment))

train_docs = [doc for doc in alldocs if doc.split == 'train']
test_docs = [doc for doc in alldocs if doc.split == 'test']

print('%d docs: %d train-sentiment, %d test-sentiment' % (len(alldocs), len(train_docs), len(test_docs)))
Because the native document-order has similar-sentiment documents in large clumps – which is suboptimal for training – we work with a once-shuffled copy of the training set.
In [3]:
from random import shuffle
doc_list = alldocs[:]
shuffle(doc_list)
We approximate the experiment of Le & Mikolov "Distributed Representations of Sentences and Documents" with guidance from Mikolov's example go.sh:
./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 -window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 -min-count 1 -sentence-vectors 1
We vary the following parameter choices:
- cbow=0 means skip-gram, which is equivalent to the paper's 'PV-DBOW' mode; this is matched in gensim with dm=0
- Added to that DBOW model are two DM models: one which averages context vectors (dm_mean) and one which concatenates them (dm_concat, resulting in a much larger, slower, more data-hungry model)
- min_count=2 saves quite a bit of model memory, discarding only words that appear in a single doc (and are thus no more expressive than the unique-to-each-doc doc-vectors themselves)
In [4]:
%%time
from gensim.models import Doc2Vec
import gensim.models.doc2vec
from collections import OrderedDict
import multiprocessing

cores = multiprocessing.cpu_count()
assert gensim.models.doc2vec.FAST_VERSION > -1, "This will be painfully slow otherwise"

simple_models = [
    # PV-DBOW plain
    Doc2Vec(dm=0, vector_size=100, negative=5, hs=0, min_count=2, sample=0,
            epochs=20, workers=cores),
    # PV-DM w/ default averaging; a higher starting alpha may improve CBOW/PV-DM modes
    Doc2Vec(dm=1, vector_size=100, window=10, negative=5, hs=0, min_count=2, sample=0,
            epochs=20, workers=cores, alpha=0.05, comment='alpha=0.05'),
    # PV-DM w/ concatenation - big, slow, experimental mode
    # window=5 (both sides) approximates paper's apparent 10-word total window size
    Doc2Vec(dm=1, dm_concat=1, vector_size=100, window=5, negative=5, hs=0, min_count=2, sample=0,
            epochs=20, workers=cores),
]

for model in simple_models:
    model.build_vocab(alldocs)
    print("%s vocabulary scanned & state initialized" % model)

models_by_name = OrderedDict((str(model), model) for model in simple_models)
Le and Mikolov note that combining a paragraph vector from Distributed Bag of Words (DBOW) and Distributed Memory (DM) improves performance. We will follow suit, pairing the models together for evaluation. Here, we concatenate the paragraph vectors obtained from each model, with the help of a thin wrapper class included in a gensim test module. (Note that this is a separate, later concatenation of output vectors, distinct from the kind of input-window concatenation enabled by the dm_concat=1 mode above.)
In [5]:
from gensim.test.test_doc2vec import ConcatenatedDoc2Vec
models_by_name['dbow+dmm'] = ConcatenatedDoc2Vec([simple_models[0], simple_models[1]])
models_by_name['dbow+dmc'] = ConcatenatedDoc2Vec([simple_models[0], simple_models[2]])
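For intuition, the wrapper's output for a given document tag is (roughly) just the two models' doc-vectors laid end to end; a hedged, hand-rolled equivalent for tag 0 might look like the following, using the docvecs access style the rest of this notebook uses.
import numpy as np

combined = np.concatenate([simple_models[0].docvecs[0], simple_models[1].docvecs[0]])
print(combined.shape)  # (200,) when both models use 100-dimensional vectors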
Let's define some helper methods for evaluating the performance of our Doc2Vec models using their paragraph vectors. We will classify document sentiments with a logistic regression model trained on those paragraph embeddings, and compare the error rates obtained with the document vectors from our various Doc2Vec models.
In [6]:
import numpy as np
import statsmodels.api as sm
from random import sample

def logistic_predictor_from_data(train_targets, train_regressors):
    """Fit a statsmodels logistic predictor on supplied data"""
    logit = sm.Logit(train_targets, train_regressors)
    predictor = logit.fit(disp=0)
    # print(predictor.summary())
    return predictor

def error_rate_for_model(test_model, train_set, test_set,
                         reinfer_train=False, reinfer_test=False,
                         infer_steps=None, infer_alpha=None, infer_subsample=0.2):
    """Report error rate on test_set sentiments, using the supplied model and train_set"""
    train_targets = [doc.sentiment for doc in train_set]
    if reinfer_train:
        train_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in train_set]
    else:
        train_regressors = [test_model.docvecs[doc.tags[0]] for doc in train_set]
    train_regressors = sm.add_constant(train_regressors)
    predictor = logistic_predictor_from_data(train_targets, train_regressors)

    test_data = test_set
    if reinfer_test:
        if infer_subsample < 1.0:
            test_data = sample(test_data, int(infer_subsample * len(test_data)))
        test_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in test_data]
    else:
        test_regressors = [test_model.docvecs[doc.tags[0]] for doc in test_data]
    test_regressors = sm.add_constant(test_regressors)

    # Predict & evaluate
    test_predictions = predictor.predict(test_regressors)
    corrects = sum(np.rint(test_predictions) == [doc.sentiment for doc in test_data])
    errors = len(test_predictions) - corrects
    error_rate = float(errors) / len(test_predictions)
    return (error_rate, errors, len(test_predictions), predictor)
Note that doc-vector training is occurring on all documents of the dataset, which includes all TRAIN/TEST/DEV docs.
We then evaluate each model's sentiment predictive power by the error rate of its logistic-regression classifier on the test documents.
(On a 4-core 2.6Ghz Intel Core i7, these 20 passes training and evaluating 3 main models takes about an hour.)
In [7]:
from collections import defaultdict
error_rates = defaultdict(lambda: 1.0) # To selectively print only best errors achieved
In [8]:
for model in simple_models:
print("Training %s" % model)
%time model.train(doc_list, total_examples=len(doc_list), epochs=model.epochs)
print("\nEvaluating %s" % model)
%time err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs)
error_rates[str(model)] = err_rate
print("\n%f %s\n" % (err_rate, model))
In [9]:
for model in [models_by_name['dbow+dmm'], models_by_name['dbow+dmc']]:
print("\nEvaluating %s" % model)
%time err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs)
error_rates[str(model)] = err_rate
print("\n%f %s\n" % (err_rate, model))
In [10]:
# Compare error rates achieved, best-to-worst
print("Err_rate Model")
for rate, name in sorted((rate, name) for name, rate in error_rates.items()):
print("%f %s" % (rate, name))
In our testing, contrary to the results of the paper, PV-DBOW alone performs as well as anything else on this problem. Concatenating vectors from different models only sometimes offers a tiny predictive improvement, and generally stays close to the best-performing solo model included.
The best results achieved here are just around 10% error rate, still a long way from the paper's reported 7.42% error rate.
(Other trials not shown, with larger vectors and other changes, also don't come close to the paper's reported value. Others around the net have reported a similar inability to reproduce the paper's best numbers. The PV-DM/C mode improves a bit with many more training epochs – but doesn't reach parity with PV-DBOW.)
In [11]:
doc_id = np.random.randint(simple_models[0].docvecs.count) # Pick random doc; re-run cell for more examples
print('for doc %d...' % doc_id)
for model in simple_models:
    inferred_docvec = model.infer_vector(alldocs[doc_id].words)
    print('%s:\n %s' % (model, model.docvecs.most_similar([inferred_docvec], topn=3)))
(Yes, here the stored vector from 20 epochs of training is usually one of the closest to a freshly-inferred vector for the same words. Defaults for inference may benefit from tuning for each dataset or model parameters.)
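For example, here is a hedged sketch of passing explicit inference parameters; note that the keyword for the number of passes differs across gensim versions (steps in older releases, epochs in newer ones), and these particular values are illustrative rather than recommended.
# More inference passes and an explicit starting learning-rate sometimes give
# more stable vectors than the defaults
tuned_vec = simple_models[0].infer_vector(alldocs[doc_id].words, alpha=0.025, steps=50)
print(simple_models[0].docvecs.most_similar([tuned_vec], topn=3))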
In [18]:
import random
doc_id = np.random.randint(simple_models[0].docvecs.count) # pick random doc, re-run cell for more examples
model = random.choice(simple_models) # and a random model
sims = model.docvecs.most_similar(doc_id, topn=model.docvecs.count) # get *all* similar documents
print(u'TARGET (%d): «%s»\n' % (doc_id, ' '.join(alldocs[doc_id].words)))
print(u'SIMILAR/DISSIMILAR DOCS PER MODEL %s:\n' % model)
for label, index in [('MOST', 0), ('MEDIAN', len(sims)//2), ('LEAST', len(sims) - 1)]:
    print(u'%s %s: «%s»\n' % (label, sims[index], ' '.join(alldocs[sims[index][0]].words)))
To some extent, in terms of reviewer tone, movie genre, etc., the MOST cosine-similar docs usually seem more like the TARGET than the MEDIAN or LEAST, especially if the MOST has a cosine-similarity > 0.5. Re-run the cell to try another random target document.
In [13]:
word_models = simple_models[:]
In [23]:
import random
from IPython.display import HTML
# pick a random word with a suitable number of occurrences
while True:
    word = random.choice(word_models[0].wv.index2word)
    if word_models[0].wv.vocab[word].count > 10:
        break
# or uncomment the line below, to just pick a word from the relevant domain:
#word = 'comedy/drama'
similars_per_model = [str(model.wv.most_similar(word, topn=20)).replace('), ','),<br>\n') for model in word_models]
similar_table = ("<table><tr><th>" +
"</th><th>".join([str(model) for model in word_models]) +
"</th></tr><tr><td>" +
"</td><td>".join(similars_per_model) +
"</td></tr></table>")
print("most similar words for '%s' (%d occurences)" % (word, simple_models[0].wv.vocab[word].count))
HTML(similar_table)
Out[23]:
Do the DBOW words look meaningless? That's because the gensim DBOW model doesn't train word vectors – they remain at their randomly initialized values – unless you ask for it with the dbow_words=1 initialization parameter. Concurrent word-training slows DBOW mode significantly, and offers little improvement (and sometimes a little worsening) of the error rate on this IMDB sentiment-prediction task, but may be appropriate for other tasks, or if you also need word-vectors.
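For reference, here is a hedged sketch of what such a DBOW-plus-word-training model would look like at initialization; it is not trained in this notebook because of the extra cost.
# dbow_words=1 also trains word-vectors (skip-gram style) alongside the doc-vectors
dbow_plus_words = Doc2Vec(dm=0, dbow_words=1, vector_size=100, negative=5, hs=0,
                          min_count=2, sample=0, epochs=20, workers=cores)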
Words from DM models tend to show meaningfully similar words when there are many examples in the training data (as with 'plot' or 'actor'). (All DM modes inherently involve word-vector training concurrent with doc-vector training.)
In [15]:
# grab the file if not already local
questions_filename = 'questions-words.txt'
if not os.path.isfile(questions_filename):
    # Download the word-analogy questions file from the word2vec repository
    print("Downloading analogy questions file...")
    url = u'https://raw.githubusercontent.com/tmikolov/word2vec/master/questions-words.txt'
    r = requests.get(url)
    with smart_open(questions_filename, 'wb') as f:
        f.write(r.content)
assert os.path.isfile(questions_filename), "questions-words.txt unavailable"
print("Success, questions-words.txt is available for next steps.")
In [16]:
# Note: this analysis takes many minutes
for model in word_models:
    score, sections = model.wv.evaluate_word_analogies('questions-words.txt')
    correct, incorrect = len(sections[-1]['correct']), len(sections[-1]['incorrect'])
    print('%s: %0.2f%% correct (%d of %d)' % (model, float(correct*100)/(correct+incorrect), correct, correct+incorrect))
Even though this is a tiny, domain-specific dataset, it shows some meager capability on the general word analogies – at least for the DM/mean and DM/concat models which actually train word vectors. (The untrained random-initialized words of the DBOW model of course fail miserably.)
In [ ]:
This cell left intentionally erroneous.
Because the bulk-trained vectors had much of their training early, when the model itself was still settling, it is sometimes the case that rather than using the bulk-trained vectors, new vectors re-inferred from the final state of the model serve better as the input/test data for downstream tasks.
Our error_rate_for_model() function already has a non-default option to re-infer vectors before training/testing the classifier, so here we test that option. (This takes as long or longer than the initial bulk training, as inference is only single-threaded.)
In [24]:
for model in simple_models + [models_by_name['dbow+dmm'], models_by_name['dbow+dmc']]:
print("Evaluating %s re-inferred" % str(model))
pseudomodel_name = str(model)+"_reinferred"
%time err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs, reinfer_train=True, reinfer_test=True, infer_subsample=1.0)
error_rates[pseudomodel_name] = err_rate
print("\n%f %s\n" % (err_rate, pseudomodel_name))
In [25]:
# Compare error rates achieved, best-to-worst
print("Err_rate Model")
for rate, name in sorted((rate, name) for name, rate in error_rates.items()):
print("%f %s" % (rate, name))
Here, we do not see much benefit from re-inference. It's more likely to help when the initial training used fewer epochs (10 is also a common value in the literature for larger datasets), or perhaps when working with larger datasets.
In [ ]:
# To get copious logging output from the steps above, enable INFO-level logging
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
rootLogger = logging.getLogger()
rootLogger.setLevel(logging.INFO)
In [ ]:
# To auto-reload edited python modules while developing, enable the autoreload extension
%load_ext autoreload
%autoreload 2