In [ ]:
%matplotlib inline
In [ ]:
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
This guide shows you how to reproduce the results of the paper by Le and
Mikolov 2014 <https://arxiv.org/pdf/1405.4053.pdf> using Gensim. While the
entire paper is worth reading (it's only 9 pages), we will be focusing on
Section 3.2: "Beyond One Sentence - Sentiment Analysis with the IMDB
dataset".
This guide follows the following steps:
When examining results, we will look for answers for the following questions:
Our data for the tutorial will be the IMDB archive <http://ai.stanford.edu/~amaas/data/sentiment/>.
If you're not familiar with this dataset, then here's a brief intro: it
contains tens of thousands of movie reviews.
Each review is a single line of text containing multiple sentences, for example:
One of the best movie-dramas I have ever seen. We do a lot of acting in the
church and this is one that can be used as a resource that highlights all the
good things that actors can do in their work. I highly recommend this one,
especially for those who have an interest in acting, as a "must see."
These reviews will be the documents that we will work with in this tutorial. There are 100 thousand reviews in total.
Out of 100k reviews, 50k have a label: either positive (the reviewer liked the movie) or negative. The remaining 50k are unlabeled.
Our first task will be to prepare the dataset.
More specifically, we will:
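First, let's define a convenient datatype for holding data for a single document:
- words: the text of the document, as a list of words.
- tags: a list holding the index of the document in the entire dataset.
- split: one of train, test or extra. Determines how the document will be used (for training, testing, etc).
- sentiment: either 1 (positive), 0 (negative) or None (unlabeled document).
This data type is helpful for later evaluation and reporting.
In particular, the index stored in tags will help us quickly and easily retrieve the vectors for a document from a model.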
In [ ]:
import collections
SentimentDocument = collections.namedtuple('SentimentDocument', 'words tags split sentiment')
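As a quick illustration of the datatype, the sketch below builds one document by hand. The words and the index are made up for the example; only the field names come from the definition above.

```python
import collections

SentimentDocument = collections.namedtuple('SentimentDocument', 'words tags split sentiment')

# Hypothetical example values; real documents come from the IMDB archive.
doc = SentimentDocument(
    words=['a', 'must', 'see'],
    tags=[0],        # index of the document in the whole dataset
    split='train',
    sentiment=1.0,   # 1.0 = positive, 0.0 = negative, None = unlabeled
)

print(doc.words)     # list of tokens
print(doc.tags[0])   # the integer used to look up this document's vector
```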
We can now proceed with loading the corpus.
In [ ]:
import io
import re
import tarfile
import os.path
import smart_open
import gensim.utils
def download_dataset(url='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'):
    fname = url.split('/')[-1]

    if os.path.isfile(fname):
        return fname

    # Download the file to local storage first.
    # We can't read it on the fly because of
    # https://github.com/RaRe-Technologies/smart_open/issues/331
    with smart_open.open(url, "rb", ignore_ext=True) as fin:
        with smart_open.open(fname, 'wb', ignore_ext=True) as fout:
            while True:
                buf = fin.read(io.DEFAULT_BUFFER_SIZE)
                if not buf:
                    break
                fout.write(buf)

    return fname

def create_sentiment_document(name, text, index):
    _, split, sentiment_str, _ = name.split('/')
    sentiment = {'pos': 1.0, 'neg': 0.0, 'unsup': None}[sentiment_str]

    if sentiment is None:
        split = 'extra'

    tokens = gensim.utils.to_unicode(text).split()
    return SentimentDocument(tokens, [index], split, sentiment)

def extract_documents():
    fname = download_dataset()

    index = 0

    with tarfile.open(fname, mode='r:gz') as tar:
        for member in tar.getmembers():
            # The dot before "txt" is escaped so that it only matches a literal '.'
            if re.match(r'aclImdb/(train|test)/(pos|neg|unsup)/\d+_\d+\.txt$', member.name):
                member_bytes = tar.extractfile(member).read()
                member_text = member_bytes.decode('utf-8', errors='replace')
                assert member_text.count('\n') == 0
                yield create_sentiment_document(member.name, member_text, index)
                index += 1

alldocs = list(extract_documents())
Here's what a single document looks like
In [ ]:
print(alldocs[27])
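The split and the label of each document are taken from its path inside the archive. The sketch below repeats that parsing step on two hypothetical member names, using only the standard library. The helper name is ours and is not part of the notebook.

```python
def parse_member_name(name):
    """Return (split, sentiment) for an archive path like 'aclImdb/train/pos/0_9.txt'."""
    _, split, sentiment_str, _ = name.split('/')
    sentiment = {'pos': 1.0, 'neg': 0.0, 'unsup': None}[sentiment_str]
    if sentiment is None:
        # Unlabeled reviews get their own split name.
        split = 'extra'
    return split, sentiment

# Hypothetical member names that follow the archive layout.
labeled = parse_member_name('aclImdb/train/pos/0_9.txt')
unlabeled = parse_member_name('aclImdb/train/unsup/0_0.txt')
print(labeled)
print(unlabeled)
```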
Extract our documents and split into training/test sets
In [ ]:
train_docs = [doc for doc in alldocs if doc.split == 'train']
test_docs = [doc for doc in alldocs if doc.split == 'test']
print('%d docs: %d train-sentiment, %d test-sentiment' % (len(alldocs), len(train_docs), len(test_docs)))
We approximate the experiment of Le & Mikolov "Distributed Representations of Sentences and Documents" <http://cs.stanford.edu/~quocle/paragraph_vector.pdf> with guidance from Mikolov's example go.sh <https://groups.google.com/d/msg/word2vec-toolkit/Q49FIrNOQRo/J6KG8mUj45sJ>:
./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 -window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 -min-count 1 -sentence-vectors 1
We vary the following parameter choices:
- cbow=0 means skip-gram, which is equivalent to the paper's 'PV-DBOW' mode, matched in gensim with dm=0
- Added to that DBOW model are two DM models: one which averages context vectors (dm_mean) and one which concatenates them (dm_concat, resulting in a much larger, slower, more data-hungry model)
- min_count=2 saves quite a bit of model memory, discarding only words that appear in a single doc (and are thus no more expressive than the unique-to-each doc vectors themselves)
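To make the correspondence with go.sh easier to follow, here is a rough sketch that translates the word2vec command-line flags into Doc2Vec keyword names. The mapping table is our own reading of the two interfaces rather than an official one, and the helper name is hypothetical.

```python
import shlex

# Assumed correspondence between word2vec flags and gensim Doc2Vec keyword arguments.
FLAG_TO_KWARG = {
    '-size': 'vector_size',
    '-window': 'window',
    '-negative': 'negative',
    '-hs': 'hs',
    '-sample': 'sample',
    '-iter': 'epochs',
    '-min-count': 'min_count',
    '-threads': 'workers',
}

def word2vec_flags_to_kwargs(command):
    """Translate the flags of a word2vec command line into a dict of keyword arguments."""
    tokens = shlex.split(command)[1:]               # drop the program name
    flags = dict(zip(tokens[0::2], tokens[1::2]))   # flags come in "-name value" pairs
    kwargs = {}
    for flag, value in flags.items():
        if flag in FLAG_TO_KWARG:
            is_float = 'e' in value or '.' in value
            kwargs[FLAG_TO_KWARG[flag]] = float(value) if is_float else int(value)
    if flags.get('-cbow') == '0':
        kwargs['dm'] = 0  # skip-gram with sentence vectors corresponds to PV-DBOW
    return kwargs

go_sh = ('./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 '
         '-window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 '
         '-min-count 1 -sentence-vectors 1')
go_sh_kwargs = word2vec_flags_to_kwargs(go_sh)
print(go_sh_kwargs)
```

The values printed are those of go.sh itself; the notebook deliberately deviates from some of them, for instance by using min_count=2.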
In [ ]:
import multiprocessing
from collections import OrderedDict
import gensim.models.doc2vec
assert gensim.models.doc2vec.FAST_VERSION > -1, "This will be painfully slow otherwise"
from gensim.models.doc2vec import Doc2Vec
common_kwargs = dict(
    vector_size=100, epochs=20, min_count=2,
    sample=0, workers=multiprocessing.cpu_count(), negative=5, hs=0,
)

simple_models = [
    # PV-DBOW plain
    Doc2Vec(dm=0, **common_kwargs),
    # PV-DM w/ default averaging; a higher starting alpha may improve CBOW/PV-DM modes
    Doc2Vec(dm=1, window=10, alpha=0.05, comment='alpha=0.05', **common_kwargs),
    # PV-DM w/ concatenation - big, slow, experimental mode
    # window=5 (both sides) approximates paper's apparent 10-word total window size
    Doc2Vec(dm=1, dm_concat=1, window=5, **common_kwargs),
]

for model in simple_models:
    model.build_vocab(alldocs)
    print("%s vocabulary scanned & state initialized" % model)

models_by_name = OrderedDict((str(model), model) for model in simple_models)
Le and Mikolov note that combining a paragraph vector from Distributed Bag of
Words (DBOW) and Distributed Memory (DM) improves performance. We will
follow, pairing the models together for evaluation. Here, we concatenate the
paragraph vectors obtained from each model with the help of a thin wrapper
class included in a gensim test module. (Note that this is a separate, later
concatenation of output vectors, distinct from the input-window concatenation
enabled by the dm_concat=1 mode above.)
In [ ]:
from gensim.test.test_doc2vec import ConcatenatedDoc2Vec
models_by_name['dbow+dmm'] = ConcatenatedDoc2Vec([simple_models[0], simple_models[1]])
models_by_name['dbow+dmc'] = ConcatenatedDoc2Vec([simple_models[0], simple_models[2]])
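Conceptually, the pairing just joins the two document vectors end to end. The following sketch shows the idea on plain Python lists with made-up 3-dimensional vectors; it is an illustration of the concept, not the wrapper class itself.

```python
def concatenate_docvecs(*vectors):
    """Join several vectors for the same document into one longer vector."""
    combined = []
    for vec in vectors:
        combined.extend(vec)
    return combined

# Hypothetical vectors for one document from two different models.
dbow_vec = [0.1, -0.2, 0.3]
dm_vec = [0.5, 0.0, -0.4]

paired = concatenate_docvecs(dbow_vec, dm_vec)
print(len(paired))  # the paired vector has the dimensions of both inputs combined
```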
Given a document, our Doc2Vec models output a vector representation of the document.
How useful is a particular model?
In the case of sentiment analysis, we want the output vector to reflect the sentiment in the input document.
So, in vector space, positive documents should be distant from negative documents.
We train a logistic regression on the training set, using the document vectors from the Doc2Vec model as regressors (inputs) and the sentiment labels as targets (outputs).
So, this logistic regression will be able to predict sentiment given a document vector.
Next, we test our logistic regression on the test set, and measure the rate of errors (incorrect predictions). If the document vectors from the Doc2Vec model reflect the actual sentiment well, the error rate will be low.
Therefore, the error rate of the logistic regression is an indication of how well the given Doc2Vec model represents documents as vectors.
We can then compare different Doc2Vec models by looking at their error rates.
In [ ]:
import numpy as np
import statsmodels.api as sm
from random import sample
def logistic_predictor_from_data(train_targets, train_regressors):
    """Fit a statsmodels logistic predictor on supplied data"""
    logit = sm.Logit(train_targets, train_regressors)
    predictor = logit.fit(disp=0)
    # print(predictor.summary())
    return predictor

def error_rate_for_model(test_model, train_set, test_set):
    """Report error rate on test_doc sentiments, using supplied model and train_docs"""

    train_targets = [doc.sentiment for doc in train_set]
    train_regressors = [test_model.docvecs[doc.tags[0]] for doc in train_set]
    train_regressors = sm.add_constant(train_regressors)
    predictor = logistic_predictor_from_data(train_targets, train_regressors)

    test_regressors = [test_model.docvecs[doc.tags[0]] for doc in test_set]
    test_regressors = sm.add_constant(test_regressors)

    # Predict & evaluate
    test_predictions = predictor.predict(test_regressors)
    corrects = sum(np.rint(test_predictions) == [doc.sentiment for doc in test_set])
    errors = len(test_predictions) - corrects
    error_rate = float(errors) / len(test_predictions)
    return (error_rate, errors, len(test_predictions), predictor)
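The final step of the evaluation is simple arithmetic: round each predicted probability to 0 or 1 and count the mismatches against the true labels. A minimal standard-library sketch with made-up predictions follows; the helper name is ours, and a plain threshold stands in for np.rint.

```python
def error_rate(predicted_probabilities, true_labels):
    """Fraction of predictions whose rounded value differs from the true label."""
    # For probabilities in [0, 1], "p > 0.5" matches np.rint, which rounds 0.5 down to 0.
    predictions = [1.0 if p > 0.5 else 0.0 for p in predicted_probabilities]
    errors = sum(1 for pred, label in zip(predictions, true_labels) if pred != label)
    return errors / len(true_labels)

# Hypothetical predicted probabilities and true sentiments for four test documents.
probs = [0.9, 0.2, 0.6, 0.4]
labels = [1.0, 0.0, 0.0, 0.0]
rate = error_rate(probs, labels)
print(rate)
```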
Note that doc-vector training occurs on all documents of the dataset, which includes all TRAIN/TEST/DEV docs. Because the native document order has similar-sentiment documents in large clumps, which is suboptimal for training, we work with a once-shuffled copy of the training set.
We evaluate each model's sentiment-predictive power by its error rate.
(On a 4-core 2.6 GHz Intel Core i7, these 20 training passes plus the evaluation of the 3 main models take about an hour.)
In [ ]:
from collections import defaultdict
error_rates = defaultdict(lambda: 1.0) # To selectively print only best errors achieved
In [ ]:
from random import shuffle
shuffled_alldocs = alldocs[:]
shuffle(shuffled_alldocs)

for model in simple_models:
    print("Training %s" % model)
    model.train(shuffled_alldocs, total_examples=len(shuffled_alldocs), epochs=model.epochs)

    print("\nEvaluating %s" % model)
    err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs)
    error_rates[str(model)] = err_rate
    print("\n%f %s\n" % (err_rate, model))

for model in [models_by_name['dbow+dmm'], models_by_name['dbow+dmc']]:
    print("\nEvaluating %s" % model)
    err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs)
    error_rates[str(model)] = err_rate
    print("\n%f %s\n" % (err_rate, model))
In [ ]:
print("Err_rate Model")
for rate, name in sorted((rate, name) for name, rate in error_rates.items()):
    print("%f %s" % (rate, name))
In our testing, contrary to the results of the paper, on this problem, PV-DBOW alone performs as well as anything else. Concatenating vectors from different models only sometimes offers a tiny predictive improvement, and generally stays close to the best-performing solo model included.
The best results achieved here are just around 10% error rate, still a long way from the paper's reported 7.42% error rate.
(Other trials not shown, with larger vectors and other changes, also don't come close to the paper's reported value. Others around the net have reported a similar inability to reproduce the paper's best numbers. The PV-DM/C mode improves a bit with many more training epochs – but doesn't reach parity with PV-DBOW.)
In [ ]:
doc_id = np.random.randint(simple_models[0].docvecs.count) # Pick random doc; re-run cell for more examples
print('for doc %d...' % doc_id)
for model in simple_models:
    inferred_docvec = model.infer_vector(alldocs[doc_id].words)
    print('%s:\n %s' % (model, model.docvecs.most_similar([inferred_docvec], topn=3)))
(Yes, here the stored vector from 20 epochs of training is usually one of the closest to a freshly-inferred vector for the same words. Defaults for inference may benefit from tuning for each dataset or model parameters.)
In [ ]:
import random
doc_id = np.random.randint(simple_models[0].docvecs.count) # pick random doc, re-run cell for more examples
model = random.choice(simple_models) # and a random model
sims = model.docvecs.most_similar(doc_id, topn=model.docvecs.count) # get *all* similar documents
print(u'TARGET (%d): «%s»\n' % (doc_id, ' '.join(alldocs[doc_id].words)))
print(u'SIMILAR/DISSIMILAR DOCS PER MODEL %s:\n' % model)
for label, index in [('MOST', 0), ('MEDIAN', len(sims)//2), ('LEAST', len(sims) - 1)]:
    s = sims[index]
    i = sims[index][0]
    words = ' '.join(alldocs[i].words)
    print(u'%s %s: «%s»\n' % (label, s, words))
Somewhat, in terms of reviewer tone, movie genre, etc... the MOST cosine-similar docs usually seem more like the TARGET than the MEDIAN or LEAST... especially if the MOST has a cosine-similarity > 0.5. Re-run the cell to try another random target document.
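For readers who want to see what the ranking amounts to, the sketch below computes cosine similarities by hand for a few made-up 2-dimensional vectors and then picks the most, median and least similar entries the same way as the cell above. All data here is hypothetical.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical target vector and candidate document vectors.
target = [1.0, 0.0]
others = {
    'doc_a': [2.0, 0.0],    # same direction as the target
    'doc_c': [0.0, 1.0],    # orthogonal to the target
    'doc_d': [-1.0, 0.0],   # opposite direction
}

# Sort from most to least similar, as (name, similarity) pairs.
sims = sorted(
    ((name, cosine_similarity(target, vec)) for name, vec in others.items()),
    key=lambda item: item[1],
    reverse=True,
)

for label, index in [('MOST', 0), ('MEDIAN', len(sims) // 2), ('LEAST', len(sims) - 1)]:
    print(label, sims[index])
```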
In [ ]:
import random
word_models = simple_models[:]

def pick_random_word(model, threshold=10):
    # pick a random word with a suitable number of occurrences
    while True:
        word = random.choice(model.wv.index2word)
        if model.wv.vocab[word].count > threshold:
            return word

target_word = pick_random_word(word_models[0])
# or uncomment below line, to just pick a word from the relevant domain:
# target_word = 'comedy/drama'

for model in word_models:
    print('target_word: %r model: %s similar words:' % (target_word, model))
    for i, (word, sim) in enumerate(model.wv.most_similar(target_word, topn=10), 1):
        print('    %d. %.2f %r' % (i, sim, word))
    print()
Do the DBOW words look meaningless? That's because the gensim DBOW model
doesn't train word vectors – they remain at their random initialized values –
unless you ask with the dbow_words=1 initialization parameter. Concurrent
word-training slows DBOW mode significantly, and offers little improvement
(and sometimes a little worsening) of the error rate on this IMDB
sentiment-prediction task, but may be appropriate on other tasks, or if you
also need word-vectors.
Words from DM models tend to show meaningfully similar words when there are many examples in the training data (as with 'plot' or 'actor'). (All DM modes inherently involve word-vector training concurrent with doc-vector training.)
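If word vectors are needed from the DBOW model as well, the dbow_words=1 option can be added to the constructor arguments. The sketch below only assembles the keyword arguments, mirroring common_kwargs from earlier with workers left out for brevity. The window value is an assumption for illustration, and the Doc2Vec call is shown as a comment so that the snippet runs without gensim installed.

```python
# Keyword arguments for a PV-DBOW model that also trains word vectors.
dbow_words_kwargs = dict(
    dm=0,           # PV-DBOW
    dbow_words=1,   # also train word vectors alongside the doc vectors
    window=10,      # assumed context window for the word-vector training
    vector_size=100, epochs=20, min_count=2,
    sample=0, negative=5, hs=0,
)

# With gensim available, the model would be created like the others:
# from gensim.models.doc2vec import Doc2Vec
# model = Doc2Vec(**dbow_words_kwargs)
print(dbow_words_kwargs)
```

As noted above, expect this variant to train noticeably more slowly than plain DBOW.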
In [ ]:
# grab the file if not already local
questions_filename = 'questions-words.txt'
if not os.path.isfile(questions_filename):
    # Download the analogy questions file
    print("Downloading analogy questions file...")
    url = u'https://raw.githubusercontent.com/tmikolov/word2vec/master/questions-words.txt'
    with smart_open.open(url, 'rb') as fin:
        with smart_open.open(questions_filename, 'wb') as fout:
            fout.write(fin.read())
assert os.path.isfile(questions_filename), "questions-words.txt unavailable"
print("Success, questions-words.txt is available for next steps.")

# Note: this analysis takes many minutes
for model in word_models:
    score, sections = model.wv.evaluate_word_analogies(questions_filename)
    correct, incorrect = len(sections[-1]['correct']), len(sections[-1]['incorrect'])
    print('%s: %0.2f%% correct (%d of %d)' % (model, float(correct*100)/(correct+incorrect), correct, correct+incorrect))
Even though this is a tiny, domain-specific dataset, it shows some meager capability on the general word analogies – at least for the DM/mean and DM/concat models which actually train word vectors. (The untrained random-initialized words of the DBOW model of course fail miserably.)
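The percentage printed for each model is computed from the last entry of sections, which the code above treats as the overall total. The sketch below applies the same arithmetic to a hand-made structure. The section names and question tuples are invented; only the 'correct' and 'incorrect' keys are taken from the code above, and the 'section' key is an assumption.

```python
def percent_correct(section):
    """Percentage of correctly answered analogy questions in one section."""
    correct, incorrect = len(section['correct']), len(section['incorrect'])
    return 100.0 * correct / (correct + incorrect)

# Hypothetical evaluation result: one category followed by a total entry.
category = {
    'section': 'example-category',
    'correct': [('a', 'b', 'c', 'd')],
    'incorrect': [('e', 'f', 'g', 'h'), ('i', 'j', 'k', 'l'), ('m', 'n', 'o', 'p')],
}
total = {
    'section': 'total',
    'correct': list(category['correct']),
    'incorrect': list(category['incorrect']),
}
sections = [category, total]

overall = percent_correct(sections[-1])
print('%0.2f%% correct' % overall)
```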