Gensim Doc2vec Tutorial on the IMDB Sentiment Dataset

Introduction

In this tutorial, we will learn how to apply Doc2vec using gensim by recreating the results of Le and Mikolov (2014).

Bag-of-words Model

Previous state-of-the-art document representations were based on the bag-of-words model, which represents each input document as a fixed-length vector. For example, borrowing from the Wikipedia article, the two documents
(1) John likes to watch movies. Mary likes movies too.
(2) John also likes to watch football games.
are used to construct a length 10 list of words
["John", "likes", "to", "watch", "movies", "Mary", "too", "also", "football", "games"]
so that each document can be represented as a fixed-length vector whose elements are the frequencies of the corresponding vocabulary words:
(1) [1, 2, 1, 1, 2, 1, 1, 0, 0, 0]
(2) [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]
Bag-of-words models are surprisingly effective, but they lose all information about word order. Bag-of-n-grams models capture some local word order by counting phrases of length n, but they suffer from data sparsity and high dimensionality.
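
For concreteness, here is a minimal bag-of-words sketch in plain Python, using the toy vocabulary and (pre-tokenized) documents above:

# Count how often each vocabulary word occurs in a token list
vocab = ["John", "likes", "to", "watch", "movies", "Mary", "too", "also", "football", "games"]
doc1 = ["John", "likes", "to", "watch", "movies", "Mary", "likes", "movies", "too"]
doc2 = ["John", "also", "likes", "to", "watch", "football", "games"]

def bag_of_words(tokens, vocab):
    return [tokens.count(word) for word in vocab]

print(bag_of_words(doc1, vocab))  # [1, 2, 1, 1, 2, 1, 1, 0, 0, 0]
print(bag_of_words(doc2, vocab))  # [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]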

Word2vec Model

Word2vec is a more recent model that embeds words in a dense, lower-dimensional vector space (relative to one-hot encodings) using a shallow neural network. The result is a set of word vectors where vectors close together in vector space have similar meanings based on context, and word vectors distant from each other have differing meanings. For example, strong and powerful would be close together, while strong and Paris would be relatively far apart. There are two versions of this model, based on skip-grams and continuous bag of words.
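
Once trained, these relationships can be checked directly. A sketch, assuming a gensim Word2Vec model named model has already been trained on a suitably large corpus (the variable name is ours, not from this tutorial):

# Cosine similarity between word vectors: related words score higher
print(model.wv.similarity('strong', 'powerful'))  # relatively high
print(model.wv.similarity('strong', 'paris'))     # relatively low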

Word2vec - Skip-gram Model

The skip-gram word2vec model takes in pairs (word1, word2) generated by moving a window across text data, and trains a 1-hidden-layer neural network on the synthetic task of predicting, given an input word, a probability distribution over the words near it. The words are one-hot encoded, and the learned input-to-hidden (projection) weights are interpreted as the word embeddings. So if the hidden layer has 300 neurons, this network will give us 300-dimensional word embeddings.
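
To make the windowing concrete, here is a small sketch (our own illustration, not gensim's internals) of generating skip-gram training pairs with a window of 2:

def skipgram_pairs(tokens, window=2):
    # Pair each center word with every word up to `window` positions away
    pairs = []
    for i, center in enumerate(tokens):
        for j in range(max(0, i - window), min(len(tokens), i + window + 1)):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs

print(skipgram_pairs(["john", "likes", "to", "watch", "movies"]))
# [('john', 'likes'), ('john', 'to'), ('likes', 'john'), ('likes', 'to'), ...]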

Word2vec - Continuous-bag-of-words Model

Continuous-bag-of-words (CBOW) Word2vec is very similar to the skip-gram model: it is also a 1-hidden-layer neural network. The synthetic task is reversed: given the context words in a window around a center word, predict the center word. Again, the words are one-hot encoded and the projection weights give us the word embeddings.
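
In gensim, choosing between the two Word2vec variants is a single flag. A toy sketch (min_count=1 only so the tiny corpus survives vocabulary pruning):

from gensim.models import Word2Vec

sentences = [["john", "likes", "movies"], ["mary", "likes", "movies", "too"]]
skipgram_model = Word2Vec(sentences, sg=1, min_count=1)  # sg=1 selects skip-gram
cbow_model = Word2Vec(sentences, sg=0, min_count=1)      # sg=0 (the default) selects CBOW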

Paragraph Vector

Le and Mikolov (2014) introduce the Paragraph Vector, which outperforms more naïve document representations such as averaging a document's Word2vec word vectors. The idea is straightforward: we treat a paragraph (or document) as just another vector, like a word vector, but call it a paragraph vector, and we learn its embedding in vector space in the same way as for words. The paragraph vector model is sensitive to local word order, like bag of n-grams, but gives us a dense representation instead of a sparse, high-dimensional one.
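
For reference, the naïve averaging baseline mentioned above might look like this sketch (again assuming a trained Word2Vec model named model):

import numpy as np

def average_word_vectors(tokens, model):
    # Represent a document as the mean of its in-vocabulary word vectors
    vectors = [model.wv[word] for word in tokens if word in model.wv.vocab]
    return np.mean(vectors, axis=0)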

Paragraph Vector - Distributed Memory (PV-DM)

This is the Paragraph Vector model analogous to CBOW Word2vec. The paragraph vectors are obtained by training a neural network on the synthetic task of inferring a center word from its context words plus the context paragraph: the paragraph vector serves as an extra piece of context shared by every window within that paragraph.

Paragraph Vector - Distributed Bag of Words (PV-DBOW)

This is the Paragraph Vector model analogous to skip-gram Word2vec. The paragraph vectors are obtained by training a neural network on the synthetic task of predicting words randomly sampled from a paragraph, given only that paragraph's vector.
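
Both variants share the same gensim interface, differing only in the dm flag. A toy sketch of training each and inferring a vector for an unseen document (the real settings used in this tutorial appear below):

from gensim.models import Doc2Vec
from gensim.models.doc2vec import TaggedDocument

docs = [TaggedDocument(words=["john", "likes", "movies"], tags=[0]),
        TaggedDocument(words=["mary", "likes", "football"], tags=[1])]

pv_dm = Doc2Vec(docs, dm=1, size=50, min_count=1, iter=10)    # PV-DM
pv_dbow = Doc2Vec(docs, dm=0, size=50, min_count=1, iter=10)  # PV-DBOW
vector = pv_dm.infer_vector(["john", "likes", "football"])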

Requirements

In addition to gensim itself, the following python modules are dependencies for this tutorial:

  • requests ( pip install requests )
  • smart_open ( pip install smart_open )
  • testfixtures ( pip install testfixtures )
  • statsmodels ( pip install statsmodels )

Load corpus

Let's download the IMDB archive if it is not already downloaded (84 MB). This will be our text data for this tutorial.
The data can be found here: http://ai.stanford.edu/~amaas/data/sentiment/


In [1]:
import locale
import glob
import os.path
import requests
import tarfile
import sys
import smart_open

dirname = 'aclImdb'
filename = 'aclImdb_v1.tar.gz'
locale.setlocale(locale.LC_ALL, 'C')

if sys.version_info[0] >= 3:
    control_chars = [chr(0x85)]
else:
    control_chars = [unichr(0x85)]

# Convert text to lower-case and strip punctuation/symbols from words
def normalize_text(text):
    norm_text = text.lower()
    # Replace breaks with spaces
    norm_text = norm_text.replace('<br />', ' ')
    # Pad punctuation with spaces on both sides
    for char in ['.', '"', ',', '(', ')', '!', '?', ';', ':']:
        norm_text = norm_text.replace(char, ' ' + char + ' ')
    return norm_text

import time
start = time.time()

if not os.path.isfile('aclImdb/alldata-id.txt'):
    if not os.path.isdir(dirname):
        if not os.path.isfile(filename):
            # Download IMDB archive
            print("Downloading IMDB archive...")
            url = u'http://ai.stanford.edu/~amaas/data/sentiment/' + filename
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)
        tar = tarfile.open(filename, mode='r')
        tar.extractall()
        tar.close()

    # Concatenate and normalize test/train data
    print("Cleaning up dataset...")
    folders = ['train/pos', 'train/neg', 'test/pos', 'test/neg', 'train/unsup']
    alldata = u''
    for fol in folders:
        temp = u''
        output = fol.replace('/', '-') + '.txt'
        # Is there a better pattern to use?
        txt_files = glob.glob(os.path.join(dirname, fol, '*.txt'))
        for txt in txt_files:
            with smart_open.smart_open(txt, "rb") as t:
                t_clean = t.read().decode("utf-8")
                for c in control_chars:
                    t_clean = t_clean.replace(c, ' ')
                temp += t_clean
            temp += "\n"
        temp_norm = normalize_text(temp)
        with smart_open.smart_open(os.path.join(dirname, output), "wb") as n:
            n.write(temp_norm.encode("utf-8"))
        alldata += temp_norm

    with smart_open.smart_open(os.path.join(dirname, 'alldata-id.txt'), 'wb') as f:
        for idx, line in enumerate(alldata.splitlines()):
            num_line = u"_*{0} {1}\n".format(idx, line)
            f.write(num_line.encode("utf-8"))

end = time.time()
print("Total running time: ", end - start)


Total running time:  0.00035199999999990794

In [2]:
import os.path
assert os.path.isfile("aclImdb/alldata-id.txt"), "alldata-id.txt unavailable"

The text data is small enough to be read into memory.


In [3]:
import gensim
from gensim.models.doc2vec import TaggedDocument
from collections import namedtuple

SentimentDocument = namedtuple('SentimentDocument', 'words tags split sentiment')

alldocs = []  # Will hold all docs in original order
with open('aclImdb/alldata-id.txt', encoding='utf-8') as alldata:
    for line_no, line in enumerate(alldata):
        tokens = gensim.utils.to_unicode(line).split()
        words = tokens[1:]
        tags = [line_no] # 'tags = [tokens[0]]' would also work at extra memory cost
        split = ['train', 'test', 'extra', 'extra'][line_no//25000]  # 25k train, 25k test, 50k extra
        sentiment = [1.0, 0.0, 1.0, 0.0, None, None, None, None][line_no//12500] # [12.5K pos, 12.5K neg]*2 then unknown
        alldocs.append(SentimentDocument(words, tags, split, sentiment))

train_docs = [doc for doc in alldocs if doc.split == 'train']
test_docs = [doc for doc in alldocs if doc.split == 'test']
doc_list = alldocs[:]  # For reshuffling per pass

print('%d docs: %d train-sentiment, %d test-sentiment' % (len(doc_list), len(train_docs), len(test_docs)))


100000 docs: 25000 train-sentiment, 25000 test-sentiment

Set-up Doc2Vec Training & Evaluation Models

We approximate the experiment of Le & Mikolov "Distributed Representations of Sentences and Documents" with guidance from Mikolov's example go.sh:

./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 -window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 -min-count 1 -sentence-vectors 1

We vary the following parameter choices:

  • 100-dimensional vectors, as the 400-d vectors of the paper don't seem to offer much benefit on this task
  • Similarly, frequent word subsampling seems to decrease sentiment-prediction accuracy, so it's left out
  • cbow=0 means skip-gram which is equivalent to the paper's 'PV-DBOW' mode, matched in gensim with dm=0
  • Added to that DBOW model are two DM models, one which averages context vectors (dm_mean) and one which concatenates them (dm_concat, resulting in a much larger, slower, more data-hungry model)
  • A min_count=2 saves quite a bit of model memory, discarding only words that appear in a single doc (and are thus no more expressive than the unique-to-each-doc vectors themselves)

In [4]:
from gensim.models import Doc2Vec
import gensim.models.doc2vec
from collections import OrderedDict
import multiprocessing

cores = multiprocessing.cpu_count()
assert gensim.models.doc2vec.FAST_VERSION > -1, "This will be painfully slow otherwise"

simple_models = [
    # PV-DM w/ concatenation - window=5 (both sides) approximates paper's 10-word total window size
    Doc2Vec(dm=1, dm_concat=1, size=100, window=5, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DBOW 
    Doc2Vec(dm=0, size=100, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DM w/ average
    Doc2Vec(dm=1, dm_mean=1, size=100, window=10, negative=5, hs=0, min_count=2, workers=cores),
]

# Speed up setup by sharing results of the 1st model's vocabulary scan
simple_models[0].build_vocab(alldocs)  # PV-DM w/ concat requires one special NULL word so it serves as template
print(simple_models[0])
for model in simple_models[1:]:
    model.reset_from(simple_models[0])
    print(model)

models_by_name = OrderedDict((str(model), model) for model in simple_models)


Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)
Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)
Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)

Le and Mikolov note that combining paragraph vectors from the Distributed Bag of Words (DBOW) and Distributed Memory (DM) models improves performance, so we follow suit and pair the models for evaluation. Here, we concatenate the paragraph vectors obtained from each model.


In [5]:
from gensim.test.test_doc2vec import ConcatenatedDoc2Vec
models_by_name['dbow+dmm'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[2]])
models_by_name['dbow+dmc'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[0]])
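
Conceptually, the paired model's vector for a document is just the member models' vectors for that document joined end to end, as in this sketch:

import numpy as np

# 100-d DBOW vector + 100-d DM vector -> one 200-d regressor per document
combined = np.concatenate([simple_models[1].docvecs[0], simple_models[0].docvecs[0]])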

Predictive Evaluation Methods

Let's define some helper methods for evaluating the predictive performance of our models. We will classify document sentiments with a logistic regression model fit on the paragraph embeddings, and compare the resulting error rates across our various Doc2vec models.


In [6]:
import numpy as np
import statsmodels.api as sm
from random import sample

# For timing
from contextlib import contextmanager
from timeit import default_timer

@contextmanager
def elapsed_timer():
    start = default_timer()
    elapser = lambda: default_timer() - start
    yield lambda: elapser()
    end = default_timer()
    elapser = lambda: end-start
    
def logistic_predictor_from_data(train_targets, train_regressors):
    logit = sm.Logit(train_targets, train_regressors)
    predictor = logit.fit(disp=0)
    # print(predictor.summary())
    return predictor

def error_rate_for_model(test_model, train_set, test_set, infer=False, infer_steps=3, infer_alpha=0.1, infer_subsample=0.1):
    """Report error rate on test_doc sentiments, using supplied model and train_docs"""

    train_targets, train_regressors = zip(*[(doc.sentiment, test_model.docvecs[doc.tags[0]]) for doc in train_set])
    train_regressors = sm.add_constant(train_regressors)
    predictor = logistic_predictor_from_data(train_targets, train_regressors)

    test_data = test_set
    if infer:
        if infer_subsample < 1.0:
            test_data = sample(test_data, int(infer_subsample * len(test_data)))
        test_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in test_data]
    else:
        test_regressors = [test_model.docvecs[doc.tags[0]] for doc in test_data]
    test_regressors = sm.add_constant(test_regressors)
    
    # Predict & evaluate
    test_predictions = predictor.predict(test_regressors)
    corrects = sum(np.rint(test_predictions) == [doc.sentiment for doc in test_data])
    errors = len(test_predictions) - corrects
    error_rate = float(errors) / len(test_predictions)
    return (error_rate, errors, len(test_predictions), predictor)


/Users/daniel/miniconda3/envs/gensim/lib/python3.6/site-packages/statsmodels/compat/pandas.py:56: FutureWarning: The pandas.core.datetools module is deprecated and will be removed in a future version. Please use the pandas.tseries module instead.
  from pandas.core import datetools

Bulk Training

We use an explicit multiple-pass, alpha-reduction approach as sketched in this gensim doc2vec blog post with added shuffling of corpus on each pass.

Note that vector training is occurring on all documents of the dataset, which includes all TRAIN/TEST/DEV docs.

We evaluate each model's sentiment predictive power based on error rate, and the evaluation is repeated after each pass so we can see the rates of relative improvement. The base numbers reuse the TRAIN and TEST vectors stored in the models for the logistic regression, while the inferred results use newly-inferred TEST vectors.

(On a 4-core 2.6 GHz Intel Core i7, these 20 passes of training and evaluating the 3 main models take about an hour.)


In [7]:
from collections import defaultdict
best_error = defaultdict(lambda: 1.0)  # To selectively print only best errors achieved

In [8]:
from random import shuffle
import datetime

alpha, min_alpha, passes = (0.025, 0.001, 20)
alpha_delta = (alpha - min_alpha) / passes

print("START %s" % datetime.datetime.now())

for epoch in range(passes):
    shuffle(doc_list)  # Shuffling gets best results
    
    for name, train_model in models_by_name.items():
        # Train
        duration = 'na'
        train_model.alpha, train_model.min_alpha = alpha, alpha
        with elapsed_timer() as elapsed:
            train_model.train(doc_list, total_examples=len(doc_list), epochs=1)
            duration = '%.1f' % elapsed()
            
        # Evaluate
        eval_duration = ''
        with elapsed_timer() as eval_elapsed:
            err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs)
        eval_duration = '%.1f' % eval_elapsed()
        best_indicator = ' '
        if err <= best_error[name]:
            best_error[name] = err
            best_indicator = '*' 
        print("%s%f : %i passes : %s %ss %ss" % (best_indicator, err, epoch + 1, name, duration, eval_duration))

        if ((epoch + 1) % 5) == 0 or epoch == 0:
            eval_duration = ''
            with elapsed_timer() as eval_elapsed:
                infer_err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs, infer=True)
            eval_duration = '%.1f' % eval_elapsed()
            best_indicator = ' '
            if infer_err < best_error[name + '_inferred']:
                best_error[name + '_inferred'] = infer_err
                best_indicator = '*'
            print("%s%f : %i passes : %s %ss %ss" % (best_indicator, infer_err, epoch + 1, name + '_inferred', duration, eval_duration))

    print('Completed pass %i at alpha %f' % (epoch + 1, alpha))
    alpha -= alpha_delta
    
print("END %s" % str(datetime.datetime.now()))


START 2017-07-08 17:48:01.470463
*0.404640 : 1 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 80.4s 2.3s
*0.361200 : 1 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 80.4s 10.9s
*0.247520 : 1 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.0s 1.1s
*0.201200 : 1 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 31.0s 3.5s
*0.264120 : 1 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 38.5s 0.7s
*0.203600 : 1 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 38.5s 4.7s
*0.216600 : 1 passes : dbow+dmm 0.0s 1.7s
*0.199600 : 1 passes : dbow+dmm_inferred 0.0s 10.6s
*0.244800 : 1 passes : dbow+dmc 0.0s 2.0s
*0.219600 : 1 passes : dbow+dmc_inferred 0.0s 15.0s
Completed pass 1 at alpha 0.025000
*0.349560 : 2 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 52.7s 0.6s
*0.147400 : 2 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 20.3s 0.5s
*0.209200 : 2 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 28.3s 0.5s
*0.140280 : 2 passes : dbow+dmm 0.0s 1.4s
*0.149360 : 2 passes : dbow+dmc 0.0s 2.2s
Completed pass 2 at alpha 0.023800
*0.308760 : 3 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 50.4s 0.6s
*0.126880 : 3 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 19.5s 0.5s
*0.192560 : 3 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 37.8s 0.7s
*0.124440 : 3 passes : dbow+dmm 0.0s 1.8s
*0.126280 : 3 passes : dbow+dmc 0.0s 1.7s
Completed pass 3 at alpha 0.022600
*0.277160 : 4 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 75.2s 0.7s
*0.119120 : 4 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.0s 2.6s
*0.177960 : 4 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 48.3s 0.8s
*0.118000 : 4 passes : dbow+dmm 0.0s 2.2s
*0.119400 : 4 passes : dbow+dmc 0.0s 2.0s
Completed pass 4 at alpha 0.021400
*0.256040 : 5 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 75.2s 0.8s
*0.256800 : 5 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 75.2s 9.0s
*0.115120 : 5 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 34.0s 1.6s
*0.115200 : 5 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 34.0s 3.5s
*0.171840 : 5 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 42.5s 0.9s
*0.202400 : 5 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 42.5s 6.2s
*0.111920 : 5 passes : dbow+dmm 0.0s 2.0s
*0.118000 : 5 passes : dbow+dmm_inferred 0.0s 11.6s
*0.113040 : 5 passes : dbow+dmc 0.0s 2.2s
*0.115600 : 5 passes : dbow+dmc_inferred 0.0s 17.3s
Completed pass 5 at alpha 0.020200
*0.236880 : 6 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 70.1s 2.0s
*0.109720 : 6 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 32.2s 0.9s
*0.166320 : 6 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 44.8s 0.9s
*0.108720 : 6 passes : dbow+dmm 0.0s 2.1s
*0.108480 : 6 passes : dbow+dmc 0.0s 2.0s
Completed pass 6 at alpha 0.019000
*0.221640 : 7 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 84.7s 0.9s
*0.107120 : 7 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.3s 1.9s
*0.164000 : 7 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 43.0s 0.9s
*0.106160 : 7 passes : dbow+dmm 0.0s 2.0s
*0.106680 : 7 passes : dbow+dmc 0.0s 2.0s
Completed pass 7 at alpha 0.017800
*0.209360 : 8 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 64.0s 0.8s
*0.106200 : 8 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.2s 0.8s
*0.161360 : 8 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 43.0s 0.9s
*0.104480 : 8 passes : dbow+dmm 0.0s 3.0s
*0.105640 : 8 passes : dbow+dmc 0.0s 2.0s
Completed pass 8 at alpha 0.016600
*0.203520 : 9 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 66.6s 1.0s
*0.105120 : 9 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 39.1s 1.1s
*0.160960 : 9 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 43.7s 0.7s
 0.104840 : 9 passes : dbow+dmm 0.0s 2.0s
*0.104240 : 9 passes : dbow+dmc 0.0s 2.0s
Completed pass 9 at alpha 0.015400
*0.195840 : 10 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 66.5s 1.7s
*0.197600 : 10 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 66.5s 10.1s
*0.104280 : 10 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.3s 0.8s
 0.115200 : 10 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 31.3s 4.7s
*0.158800 : 10 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 44.5s 0.9s
*0.182800 : 10 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 44.5s 6.3s
*0.102760 : 10 passes : dbow+dmm 0.0s 3.1s
*0.110000 : 10 passes : dbow+dmm_inferred 0.0s 11.3s
*0.103920 : 10 passes : dbow+dmc 0.0s 2.2s
*0.109200 : 10 passes : dbow+dmc_inferred 0.0s 16.4s
Completed pass 10 at alpha 0.014200
*0.190800 : 11 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 71.3s 1.0s
*0.103840 : 11 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 33.8s 0.8s
*0.157440 : 11 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 44.5s 0.9s
 0.103240 : 11 passes : dbow+dmm 0.0s 3.0s
 0.104360 : 11 passes : dbow+dmc 0.0s 2.1s
Completed pass 11 at alpha 0.013000
*0.188520 : 12 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 65.4s 0.8s
 0.104600 : 12 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 33.3s 1.0s
*0.157240 : 12 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 53.5s 1.7s
 0.103880 : 12 passes : dbow+dmm 0.0s 2.8s
 0.104640 : 12 passes : dbow+dmc 0.0s 2.6s
Completed pass 12 at alpha 0.011800
*0.185760 : 13 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 71.8s 1.7s
 0.104040 : 13 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.9s 1.0s
*0.155960 : 13 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 45.7s 0.8s
*0.102720 : 13 passes : dbow+dmm 0.0s 2.0s
 0.104120 : 13 passes : dbow+dmc 0.0s 1.9s
Completed pass 13 at alpha 0.010600
*0.181960 : 14 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 80.3s 0.8s
*0.103680 : 14 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 23.1s 0.7s
*0.155040 : 14 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 31.4s 1.5s
*0.102440 : 14 passes : dbow+dmm 0.0s 1.6s
*0.103680 : 14 passes : dbow+dmc 0.0s 1.7s
Completed pass 14 at alpha 0.009400
*0.180680 : 15 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 48.5s 0.7s
*0.186000 : 15 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 48.5s 12.0s
 0.104840 : 15 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 23.4s 0.7s
*0.101600 : 15 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 23.4s 4.3s
*0.154000 : 15 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 53.2s 2.0s
 0.191600 : 15 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 53.2s 4.8s
 0.102960 : 15 passes : dbow+dmm 0.0s 3.1s
*0.108400 : 15 passes : dbow+dmm_inferred 0.0s 11.4s
 0.104280 : 15 passes : dbow+dmc 0.0s 1.7s
*0.098400 : 15 passes : dbow+dmc_inferred 0.0s 14.1s
Completed pass 15 at alpha 0.008200
*0.180320 : 16 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 68.3s 1.0s
*0.103600 : 16 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 28.5s 2.1s
 0.154640 : 16 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 43.4s 0.7s
 0.102520 : 16 passes : dbow+dmm 0.0s 1.9s
*0.102480 : 16 passes : dbow+dmc 0.0s 2.9s
Completed pass 16 at alpha 0.007000
*0.178160 : 17 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 63.4s 2.0s
*0.103360 : 17 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.5s 0.8s
 0.154160 : 17 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 40.9s 1.0s
*0.102320 : 17 passes : dbow+dmm 0.0s 3.0s
 0.102680 : 17 passes : dbow+dmc 0.0s 2.0s
Completed pass 17 at alpha 0.005800
*0.177520 : 18 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 55.1s 0.8s
*0.103120 : 18 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 24.8s 0.7s
*0.153040 : 18 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 32.9s 0.8s
 0.102440 : 18 passes : dbow+dmm 0.0s 1.7s
*0.102480 : 18 passes : dbow+dmc 0.0s 2.6s
Completed pass 18 at alpha 0.004600
*0.177240 : 19 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 57.2s 1.5s
*0.103080 : 19 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 20.6s 1.8s
*0.152680 : 19 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 43.8s 0.8s
 0.102800 : 19 passes : dbow+dmm 0.0s 1.8s
 0.102600 : 19 passes : dbow+dmc 0.0s 1.7s
Completed pass 19 at alpha 0.003400
*0.176080 : 20 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 50.2s 0.6s
 0.188000 : 20 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 50.2s 8.5s
 0.103400 : 20 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 19.7s 0.7s
 0.111600 : 20 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 19.7s 4.1s
*0.152680 : 20 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 30.5s 0.6s
 0.182800 : 20 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 30.5s 4.7s
 0.102600 : 20 passes : dbow+dmm 0.0s 1.6s
 0.112800 : 20 passes : dbow+dmm_inferred 0.0s 8.8s
*0.102440 : 20 passes : dbow+dmc 0.0s 2.1s
 0.103600 : 20 passes : dbow+dmc_inferred 0.0s 12.4s
Completed pass 20 at alpha 0.002200
END 2017-07-08 18:39:42.878219

Achieved Sentiment-Prediction Accuracy


In [9]:
# Print best error rates achieved
print("Err rate Model")
for rate, name in sorted((rate, name) for name, rate in best_error.items()):
    print("%f %s" % (rate, name))


Err rate Model
0.098400 dbow+dmc_inferred
0.101600 Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred
0.102320 dbow+dmm
0.102440 dbow+dmc
0.103080 Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)
0.108400 dbow+dmm_inferred
0.152680 Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)
0.176080 Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)
0.182800 Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred
0.186000 Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred

In our testing, contrary to the results of the paper, PV-DBOW performs best. Concatenating vectors from different models offers only a small predictive improvement over the best individual model. The best results reproduced here are just under a 10% error rate, still a long way from the paper's reported 7.42%.

Examining Results

Are inferred vectors close to the precalculated ones?


In [10]:
doc_id = np.random.randint(simple_models[0].docvecs.count)  # Pick random doc; re-run cell for more examples
print('for doc %d...' % doc_id)
for model in simple_models:
    inferred_docvec = model.infer_vector(alldocs[doc_id].words)
    print('%s:\n %s' % (model, model.docvecs.most_similar([inferred_docvec], topn=3)))


for doc 73872...
Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4):
 [(73872, 0.7427197694778442), (43744, 0.42404329776763916), (75113, 0.41938722133636475)]
Doc2Vec(dbow,d100,n5,mc2,s0.001,t4):
 [(73872, 0.9305995106697083), (64147, 0.6267511248588562), (80042, 0.6207213401794434)]
Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4):
 [(73872, 0.7893393039703369), (67773, 0.7167356014251709), (32802, 0.6937947273254395)]

(Yes, here the stored vector from 20 epochs of training is usually one of the closest to a freshly-inferred vector for the same words. Note the defaults for inference are very abbreviated – just 3 steps starting at a high alpha – and likely need tuning for other applications.)
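
For other applications, a slower but often steadier inference might use more steps and a lower starting alpha. A sketch (the values are illustrative, not tuned):

# More inference passes, starting from a gentler learning rate
inferred_docvec = model.infer_vector(alldocs[doc_id].words, steps=20, alpha=0.025)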


In [11]:
import random

doc_id = np.random.randint(simple_models[0].docvecs.count)  # pick random doc, re-run cell for more examples
model = random.choice(simple_models)  # and a random model
sims = model.docvecs.most_similar(doc_id, topn=model.docvecs.count)  # get *all* similar documents
print(u'TARGET (%d): «%s»\n' % (doc_id, ' '.join(alldocs[doc_id].words)))
print(u'SIMILAR/DISSIMILAR DOCS PER MODEL %s:\n' % model)
for label, index in [('MOST', 0), ('MEDIAN', len(sims)//2), ('LEAST', len(sims) - 1)]:
    print(u'%s %s: «%s»\n' % (label, sims[index], ' '.join(alldocs[sims[index][0]].words)))


TARGET (71919): «tweety is perched in his cage on the ledge and sylvester is across the street at the " bird watching society " building on about the same level . both are looking through binoculars , and they spot each other . tweety then utters his famous phrase , " i taught i taw a puddy cat . " ( thought i saw a pussy cat . ) sylvester scampers over to grab the bird . tweety flies out of his cage and granny comes to the rescue , bashing the cat and driving it away . the rest of the animated short shows a series of attempts by sylvester to grab tweetie - a familiar theme - and how either bad luck or granny thwarts him every time . the cat dons disguises and tries a number of clever schemes . . . all of which are funny and very entertaining . in all , a good cartoon and fun to watch .»

SIMILAR/DISSIMILAR DOCS PER MODEL Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4):

MOST (30440, 0.752430260181427): «in tweety's s . o . s , sylvester goes from picking garbage cans to being a stowaway on a cruise ship that happens to carry a certain canary bird-and granny , his owner . uh-oh ! once again , tweety and granny provide many obstacles to the cat's attempts to get the bird . sylvester also gets seasick quite a few times , too . and the second time the red-nosed feline goes to the place on the ship that has something that cures his ailments , tweety replaces it with nitroglycerin . so now sylvester can blow fire ! i'll stop here and say this is another excellent cartoon directed by friz freling starring the popular cat-and-bird duo . tweety's s . o . s is most highly recommended .»

MEDIAN (32141, 0.3800385594367981): «my entire family enjoyed this film , including 2 small children . great values without sex , violence , drugs , nudity , or profanity . also no zillion dollar special effects were added to try to misdirect viewers from a poorly written storyline . a simple little family fun movie . we especially like the songs in the movie . but we only got to hear a portion of the songs . . . mostly during the end credits . . . would love to buy a sound track cd from this movie . this is my 4th bill hillman movie and they all have the same guidelines as mentioned above . with all the movies out there that you don't want your kids to watch , this hillman fella has a no risk rating . we love his movies .»

LEAST (57712, -0.051298510283231735): «in a recent biography of burt lancaster , go tell the spartans is described as the best vietnam war film that nobody ever saw . hopefully with television and video products that will be corrected . i prefer to think of it as a prequel to platoon . this film is set in 1964 when america's participation was limited to advisers by this time raised to about 20 , 000 of them by president kennedy . whether if kennedy had lived and won a second term he would have increased our commitment to a half a million men as lyndon johnson did is open to much historical speculation . major burt lancaster heads such an advisory team with his number two captain marc singer . they get some replacements and a new assignment to build a fortress where the french tried years ago and failed . the replacements are a really mixed bag , a sergeant who lancaster has served with before and respects highly in jonathan goldsmith , a very green and eager second lieutenant in joe unger , a demolitions man who is a draftee and at that time vietnam service was a strictly volunteer thing in craig wasson , and a medic who is also a junkie in dennis howard . for one reason or another all of these get sent forward to build that outpost in a place that suddenly has acquired military significance . i said before this could be a prequel to platoon . platoon is set in the time a few years later when the usa was fully militarily committed in vietnam . platoon raises the same issues about the futility of that war , but i think go tell the spartans does a much better job . hard to bring your best effort into the fight since who and what you're fighting and fighting for seems to change weekly . originally this project was for william holden and i'm surprised holden passed on it . maybe for the better because lancaster strikes just the right note as the professional soldier in what was a backwater assignment who politics has passed over for promotion . knowing all that you will understand why lancaster makes the final decision he does . two others of note are evan kim who is the head of the south vietnamese regulars and interpreter who lancaster and company are training . he epitomizes the brutality of the struggle for us in a way that we can't appreciate from the other side because we never meet any of the viet cong by name . dolph sweet plays the general in charge of the american vietnam commitment , a general harnitz . he is closest to a real character because the general in charge their before johnson raised the troop levels and put in william westmoreland was paul harkins . joe unger is who i think gives the best performance as the shavetail lieutenant with all the conventional ideas of war and believes we have got to be with the good guys since we are americans . he learns fast that you issue uniforms for a reason and wars against people who don't have them are the most difficult . i think one could get a deep understanding of just what america faced in 1964 in vietnam by watching go tell the spartans .»

(Somewhat, in terms of reviewer tone, movie genre, etc... the MOST cosine-similar docs usually seem more like the TARGET than the MEDIAN or LEAST.)

Do the word vectors show useful similarities?


In [12]:
word_models = simple_models[:]

In [13]:
import random
from IPython.display import HTML
# pick a random word with a suitable number of occurrences
while True:
    word = random.choice(word_models[0].wv.index2word)
    if word_models[0].wv.vocab[word].count > 10:
        break
# or uncomment below line, to just pick a word from the relevant domain:
#word = 'comedy/drama'
similars_per_model = [str(model.most_similar(word, topn=20)).replace('), ','),<br>\n') for model in word_models]
similar_table = ("<table><tr><th>" +
    "</th><th>".join([str(model) for model in word_models]) + 
    "</th></tr><tr><td>" +
    "</td><td>".join(similars_per_model) +
    "</td></tr></table>")
print("most similar words for '%s' (%d occurences)" % (word, simple_models[0].wv.vocab[word].count))
HTML(similar_table)


most similar words for 'thrilled' (276 occurrences)
Out[13]:
Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4):
[('pleased', 0.8135600090026855),
('excited', 0.7601636648178101),
('surprised', 0.7497514486312866),
('delighted', 0.740871012210846),
('impressed', 0.7300887107849121),
('disappointed', 0.715817391872406),
('shocked', 0.7109759449958801),
('intrigued', 0.7000594139099121),
('amazed', 0.6994709968566895),
('fascinated', 0.6952326893806458),
('saddened', 0.68060702085495),
('satisfied', 0.674963116645813),
('apprehensive', 0.6572576761245728),
('entertained', 0.654381275177002),
('disgusted', 0.6502282023429871),
('overjoyed', 0.6485082507133484),
('stunned', 0.6478738784790039),
('entranced', 0.6438385844230652),
('amused', 0.6437265872955322),
('dissappointed', 0.6427538394927979)]
[("ifans'", 0.44280144572257996),
('shay', 0.4335209131240845),
('crappers', 0.4007232189178467),
('overflow', 0.40028804540634155),
('yum', 0.3929170072078705),
("monkey'", 0.38661277294158936),
('kholi', 0.38401469588279724),
('fun-bloodbath', 0.38145124912261963),
('breathed', 0.373812735080719),
("eszterhas'", 0.3729144334793091),
('nob', 0.3723628520965576),
("meatloaf's", 0.3720172643661499),
('ruegger', 0.3683895468711853),
("haynes'", 0.36665791273117065),
('feigning', 0.36445197463035583),
('torches', 0.35865518450737),
('sirens', 0.3581739068031311),
('insides', 0.35690629482269287),
('swackhamer', 0.35603001713752747),
('trolls', 0.3526684641838074)]
Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4):
[('pleased', 0.7576382160186768),
('excited', 0.7351139187812805),
('delighted', 0.7220871448516846),
('intrigued', 0.6748061180114746),
('surprised', 0.6552557945251465),
('shocked', 0.6505781412124634),
('disappointed', 0.6428648233413696),
('impressed', 0.6426182389259338),
('overjoyed', 0.6259098052978516),
('saddened', 0.6148285865783691),
('anxious', 0.6140503883361816),
('fascinated', 0.6126223802566528),
('skeptical', 0.6025052070617676),
('suprised', 0.5986943244934082),
('upset', 0.596437931060791),
('relieved', 0.593376874923706),
('psyched', 0.5923721790313721),
('captivated', 0.5753644704818726),
('astonished', 0.574415922164917),
('horrified', 0.5716636180877686)]

Do the DBOW words look meaningless? That's because the gensim DBOW model doesn't train word vectors – they remain at their randomly-initialized values – unless you ask for them with the dbow_words=1 initialization parameter. Concurrent word-training slows DBOW mode significantly, and offers little improvement (and sometimes a little worsening) of the error rate on this IMDB sentiment-prediction task.
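
If DBOW word vectors are wanted anyway, the option is set at construction time. A sketch mirroring the DBOW model above:

# dbow_words=1 interleaves skip-gram word-vector training with doc-vector training
Doc2Vec(dm=0, dbow_words=1, size=100, negative=5, hs=0, min_count=2, workers=cores)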

The DM models tend to show meaningful word similarities when a word has many usage examples in the training data (as with 'plot' or 'actor'). (All DM modes inherently involve word-vector training concurrent with doc-vector training.)

Are the word vectors from this dataset any good at analogies?


In [14]:
# Download this file: https://github.com/nicholas-leonard/word2vec/blob/master/questions-words.txt
# and place it in the local directory
# Note: this takes many minutes
if os.path.isfile('questions-words.txt'):
    for model in word_models:
        sections = model.accuracy('questions-words.txt')
        correct, incorrect = len(sections[-1]['correct']), len(sections[-1]['incorrect'])
        print('%s: %0.2f%% correct (%d of %d)' % (model, float(correct*100)/(correct+incorrect), correct, correct+incorrect))

Even though this is a tiny, domain-specific dataset, it shows some meager capability on the general word analogies – at least for the DM/concat and DM/mean models which actually train word vectors. (The untrained random-initialized words of the DBOW model of course fail miserably.)

Slop

(The next cell is deliberately invalid: it halts a notebook-wide "Run All" before reaching the optional cells below.)


In [15]:
This cell left intentionally erroneous.

To mix the Google dataset (if locally available) into the word tests...


In [ ]:
from gensim.models import KeyedVectors
w2v_g100b = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)
w2v_g100b.compact_name = 'w2v_g100b'
word_models.append(w2v_g100b)

To get copious logging output from above steps...


In [ ]:
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
rootLogger = logging.getLogger()
rootLogger.setLevel(logging.INFO)

To auto-reload python code while developing...


In [ ]:
%load_ext autoreload
%autoreload 2