gensim doc2vec & IMDB sentiment dataset

TODO: section on introduction & motivation

TODO: prerequisites + dependencies (statsmodels, patsy, ?)

Requirements

The following are the dependencies for this tutorial:

- gensim
- numpy
- requests
- statsmodels
- testfixtures

Load corpus

Fetch and prepare the data as in Mikolov's go.sh shell script. (This cell checks for the required files first, so the steps won't repeat once the final summary file, aclImdb/alldata-id.txt, is available alongside this notebook.)


In [2]:
import locale
import glob
import os.path
import requests
import tarfile
import sys
from timeit import default_timer

dirname = 'aclImdb'
filename = 'aclImdb_v1.tar.gz'
locale.setlocale(locale.LC_ALL, 'C')

# U+0085 (NEXT LINE) appears in some reviews; str.splitlines() treats it as a
# line break, so it is replaced with a space before reviews are joined
if sys.version_info[0] >= 3:
    control_chars = [chr(0x85)]
else:
    control_chars = [unichr(0x85)]

# Convert text to lower-case and pad punctuation with spaces,
# so each punctuation mark becomes a separate token
def normalize_text(text):
    norm_text = text.lower()

    # Replace HTML line breaks with spaces
    norm_text = norm_text.replace('<br />', ' ')

    # Pad punctuation with spaces on both sides
    for char in ['.', '"', ',', '(', ')', '!', '?', ';', ':']:
        norm_text = norm_text.replace(char, ' ' + char + ' ')

    return norm_text

# time.clock() was removed in Python 3.8; default_timer works on 2.x and 3.x
start = default_timer()

if not os.path.isfile(os.path.join(dirname, 'alldata-id.txt')):
    if not os.path.isdir(dirname):
        if not os.path.isfile(filename):
            # Download IMDB archive
            url = u'http://ai.stanford.edu/~amaas/data/sentiment/' + filename
            r = requests.get(url)
            r.raise_for_status()  # fail early rather than save an error page as the archive
            with open(filename, 'wb') as f:
                f.write(r.content)

        with tarfile.open(filename, mode='r') as tar:
            tar.extractall()

    # Concatenate and normalize train/test data, one review per line
    folders = ['train/pos', 'train/neg', 'test/pos', 'test/neg', 'train/unsup']
    alldata = u''

    for fol in folders:
        temp = u''
        output = fol.replace('/', '-') + '.txt'
        txt_files = glob.glob(os.path.join(dirname, fol, '*.txt'))

        for txt in txt_files:
            # plain open() is enough for local files
            with open(txt, 'rb') as t:
                t_clean = t.read().decode('utf-8')

            for c in control_chars:
                t_clean = t_clean.replace(c, ' ')

            temp += t_clean + u'\n'

        temp_norm = normalize_text(temp)

        with open(os.path.join(dirname, output), 'wb') as n:
            n.write(temp_norm.encode('utf-8'))

        alldata += temp_norm

    # Prefix each review with a unique '_*N' tag
    with open(os.path.join(dirname, 'alldata-id.txt'), 'wb') as f:
        for idx, line in enumerate(alldata.splitlines()):
            num_line = u"_*{0} {1}\n".format(idx, line)
            f.write(num_line.encode('utf-8'))

end = default_timer()
print("total running time: ", end - start)


total running time:  41.018378
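To see what the per-review cleanup in the preparation cell does, here is a standalone sketch that copies normalize_text and applies it, together with the U+0085 replacement, to a made-up review string (the sample text is invented for illustration). The U+0085 step matters because Python 3's str.splitlines() treats that character as a line boundary, which would split one review across several lines of alldata-id.txt.

```python
def normalize_text(text):
    """Same steps as the notebook's normalize_text: lower-case, drop <br />, pad punctuation."""
    norm_text = text.lower()
    norm_text = norm_text.replace('<br />', ' ')
    for char in ['.', '"', ',', '(', ')', '!', '?', ';', ':']:
        norm_text = norm_text.replace(char, ' ' + char + ' ')
    return norm_text

# Invented example review with HTML breaks, punctuation and a U+0085 character
raw = 'Great movie!<br /><br />I loved it\x85(mostly).'

# Without the replacement, splitlines() would cut this single review in two
print(len(raw.splitlines()))  # 2

cleaned = raw.replace('\x85', ' ')
tokens = normalize_text(cleaned).split()
print(tokens)  # ['great', 'movie', '!', 'i', 'loved', 'it', '(', 'mostly', ')', '.']
```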

In [2]:
import os.path
assert os.path.isfile("aclImdb/alldata-id.txt"), "alldata-id.txt unavailable"

The data is small enough to be read into memory.


In [3]:
import gensim
from gensim.models.doc2vec import TaggedDocument
from collections import namedtuple

SentimentDocument = namedtuple('SentimentDocument', 'words tags split sentiment')

alldocs = []  # will hold all docs in original order
with open('aclImdb/alldata-id.txt', encoding='utf-8') as alldata:
    for line_no, line in enumerate(alldata):
        tokens = gensim.utils.to_unicode(line).split()
        words = tokens[1:]
        tags = [line_no] # `tags = [tokens[0]]` would also work at extra memory cost
        split = ['train','test','extra','extra'][line_no//25000]  # 25k train, 25k test, 50k extra
        sentiment = [1.0, 0.0, 1.0, 0.0, None, None, None, None][line_no//12500] # [12.5K pos, 12.5K neg]*2 then unknown
        alldocs.append(SentimentDocument(words, tags, split, sentiment))

train_docs = [doc for doc in alldocs if doc.split == 'train']
test_docs = [doc for doc in alldocs if doc.split == 'test']
doc_list = alldocs[:]  # for reshuffling per pass

print('%d docs: %d train-sentiment, %d test-sentiment' % (len(doc_list), len(train_docs), len(test_docs)))


100000 docs: 25000 train-sentiment, 25000 test-sentiment
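The split and sentiment labels above come purely from each line's position in alldata-id.txt, which works because the preparation step wrote the folders in the fixed order train/pos, train/neg, test/pos, test/neg, train/unsup (12.5K, 12.5K, 12.5K, 12.5K and 50K reviews). A standalone sketch of that mapping for one made-up line in the `_*N words...` format; the helper name parse_line is hypothetical:

```python
from collections import namedtuple

SentimentDocument = namedtuple('SentimentDocument', 'words tags split sentiment')

def parse_line(line_no, line):
    """Hypothetical helper mirroring the notebook's loop body for a single line."""
    tokens = line.split()
    words = tokens[1:]  # tokens[0] is the '_*N' tag written during preparation
    split = ['train', 'test', 'extra', 'extra'][line_no // 25000]
    sentiment = [1.0, 0.0, 1.0, 0.0, None, None, None, None][line_no // 12500]
    return SentimentDocument(words, [line_no], split, sentiment)

doc = parse_line(30000, '_*30000 a fine , moving film .')
print(doc.split, doc.sentiment, doc.words[:3])  # test 1.0 ['a', 'fine', ',']
```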

Set Up Doc2Vec Training & Evaluation Models

We approximate the experiment of Le & Mikolov, "Distributed Representations of Sentences and Documents", with additional guidance from Mikolov's example go.sh:

./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 -window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 -min-count 1 -sentence-vectors 1

Notes on the parameter choices below, several of which vary from the paper and go.sh:

  • 100-dimensional vectors, as the paper's 400-dimensional vectors don't seem to offer much benefit on this task
  • similarly, frequent-word subsampling seems to decrease sentiment-prediction accuracy, so go.sh's -sample 1e-4 isn't used (gensim's default of sample=0.001 still applies, shown as s0.001 in the model names below)
  • go.sh's -cbow 0 means skip-gram, which is equivalent to the paper's PV-DBOW mode; gensim matches it with dm=0
  • added to that DBOW model are two DM models: one that averages context vectors (dm_mean) and one that concatenates them (dm_concat, resulting in a much larger, slower, more data-hungry model)
  • min_count=2 saves quite a bit of model memory, discarding only words that occur just once in the whole corpus (such words appear in a single doc, so they are no more expressive than the unique-to-each-doc vectors themselves)
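As a rough aid for comparing go.sh's flags with the Doc2Vec calls below, here is a sketch that translates the go.sh command line into a dict of pre-4.0 gensim keyword names (size, iter and so on, as used in this notebook's gensim version). The mapping table and function name are hypothetical; mapping -cbow 0 to dm=0 follows the note above, while the reverse direction (cbow 1 to dm=1) is an assumption.

```python
import shlex

# Hypothetical mapping from word2vec CLI flags to pre-4.0 gensim Doc2Vec kwargs
FLAG_TO_KWARG = {
    '-cbow': 'dm',          # -cbow 0 (skip-gram) corresponds to PV-DBOW, i.e. dm=0
    '-size': 'size',
    '-window': 'window',
    '-negative': 'negative',
    '-hs': 'hs',
    '-sample': 'sample',
    '-threads': 'workers',
    '-iter': 'iter',
    '-min-count': 'min_count',
}

GO_SH_COMMAND = ('./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 '
                 '-window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 '
                 '-min-count 1 -sentence-vectors 1')

def _number(value):
    """Parse '100' as an int and '1e-4' as a float."""
    try:
        return int(value)
    except ValueError:
        return float(value)

def gosh_to_doc2vec_kwargs(command):
    """Translate a word2vec command line into Doc2Vec keyword arguments."""
    tokens = shlex.split(command)[1:]  # drop the program name
    kwargs = {}
    for flag, value in zip(tokens[0::2], tokens[1::2]):
        name = FLAG_TO_KWARG.get(flag)
        if name is None:
            continue  # -train, -output, -binary, -sentence-vectors have no direct kwarg here
        kwargs[name] = _number(value)
    return kwargs

kwargs = gosh_to_doc2vec_kwargs(GO_SH_COMMAND)
print(kwargs['dm'], kwargs['sample'], kwargs['min_count'])  # 0 0.0001 1
```

Comparing this dict with the Doc2Vec(...) calls in the next cell shows the departures listed above, such as min_count=2 instead of 1 and no sample=1e-4.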

In [4]:
from gensim.models import Doc2Vec
import gensim.models.doc2vec
from collections import OrderedDict
import multiprocessing

cores = multiprocessing.cpu_count()
assert gensim.models.doc2vec.FAST_VERSION > -1, "this will be painfully slow otherwise"

simple_models = [
    # PV-DM w/concatenation - window=5 (both sides) approximates paper's 10-word total window size
    Doc2Vec(dm=1, dm_concat=1, size=100, window=5, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DBOW 
    Doc2Vec(dm=0, size=100, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DM w/average
    Doc2Vec(dm=1, dm_mean=1, size=100, window=10, negative=5, hs=0, min_count=2, workers=cores),
]

# speed setup by sharing results of 1st model's vocabulary scan
simple_models[0].build_vocab(alldocs)  # PV-DM/concat requires one special NULL word so it serves as template
print(simple_models[0])
for model in simple_models[1:]:
    model.reset_from(simple_models[0])
    print(model)

models_by_name = OrderedDict((str(model), model) for model in simple_models)


Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)
Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)
Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)

Following the paper, we also evaluate models in pairs. These wrappers return the concatenation of the vectors from each model. (Only the individual models are trained.)


In [5]:
from gensim.test.test_doc2vec import ConcatenatedDoc2Vec
models_by_name['dbow+dmm'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[2]])
models_by_name['dbow+dmc'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[0]])
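Conceptually, each wrapper looks up the same document tag in every underlying model and joins the resulting vectors end to end, so dbow+dmm yields 200-dimensional doc vectors from two 100-dimensional models. A minimal pure-Python sketch of that idea, assuming each "model" is just a mapping from tag to vector; the class name ConcatenatedVectors is hypothetical and this is not gensim's ConcatenatedDoc2Vec code:

```python
class ConcatenatedVectors:
    """Hypothetical stand-in: look up a tag in every model and join the vectors."""

    def __init__(self, models):
        self.models = models

    def __getitem__(self, tag):
        joined = []
        for model in self.models:
            joined.extend(model[tag])  # copies values; the source vectors are untouched
        return joined

# Toy "models" mapping tag -> vector (2-d and 3-d here instead of 100-d)
dbow = {0: [0.1, 0.2], 1: [0.3, 0.4]}
dmm = {0: [1.0, 1.1, 1.2], 1: [1.3, 1.4, 1.5]}

pair = ConcatenatedVectors([dbow, dmm])
print(pair[1])  # [0.3, 0.4, 1.3, 1.4, 1.5]
```

Because such a wrapper trains nothing itself, its training time shows as 0.0s in the log below; only its evaluation costs time.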

Predictive Evaluation Methods

Helper methods for evaluating error rate.


In [6]:
import numpy as np
import statsmodels.api as sm
from random import sample

# for timing
from contextlib import contextmanager
from timeit import default_timer
import time 

@contextmanager
def elapsed_timer():
    start = default_timer()
    elapser = lambda: default_timer() - start
    yield lambda: elapser()
    end = default_timer()
    elapser = lambda: end-start
    
def logistic_predictor_from_data(train_targets, train_regressors):
    logit = sm.Logit(train_targets, train_regressors)
    predictor = logit.fit(disp=0)
    #print(predictor.summary())
    return predictor

def error_rate_for_model(test_model, train_set, test_set, infer=False, infer_steps=3, infer_alpha=0.1, infer_subsample=0.1):
    """Report error rate on test_doc sentiments, using supplied model and train_docs"""

    train_targets, train_regressors = zip(*[(doc.sentiment, test_model.docvecs[doc.tags[0]]) for doc in train_set])
    train_regressors = sm.add_constant(train_regressors)
    predictor = logistic_predictor_from_data(train_targets, train_regressors)

    test_data = test_set
    if infer:
        if infer_subsample < 1.0:
            test_data = sample(test_data, int(infer_subsample * len(test_data)))
        test_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in test_data]
    else:
        test_regressors = [test_model.docvecs[doc.tags[0]] for doc in test_data]
    test_regressors = sm.add_constant(test_regressors)
    
    # predict & evaluate
    test_predictions = predictor.predict(test_regressors)
    corrects = sum(np.rint(test_predictions) == [doc.sentiment for doc in test_data])
    errors = len(test_predictions) - corrects
    error_rate = float(errors) / len(test_predictions)
    return (error_rate, errors, len(test_predictions), predictor)


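The scoring step at the end of error_rate_for_model rounds each predicted probability of positive sentiment to 0 or 1 (with np.rint) and counts mismatches against the labels. The same arithmetic in plain Python, on made-up numbers (the function name error_rate is hypothetical):

```python
def error_rate(predicted_probs, labels):
    """Return (fraction of docs misclassified, number misclassified)."""
    # round() rounds exact halves to even, as np.rint does
    errors = sum(1 for p, y in zip(predicted_probs, labels) if round(p) != y)
    return errors / len(predicted_probs), errors

rate, errors = error_rate([0.9, 0.2, 0.6, 0.4], [1.0, 0.0, 0.0, 1.0])
print(rate, errors)  # 0.5 2
```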

Bulk Training

We use the explicit multiple-pass, alpha-reduction approach sketched in the gensim doc2vec blog post, with the corpus additionally shuffled before each pass.

Note that vectors are trained on all documents of the dataset, including all TRAIN, TEST and unlabeled EXTRA docs.

After each pass, each model's sentiment-predictive power is evaluated as an error rate (lower is better), to show how much each pass improves it. The base numbers feed the TRAIN and TEST vectors stored in the models to the logistic regression, while the _inferred results (computed after the 1st pass and every 5th pass) use newly inferred vectors for a random 10% sample of the TEST docs.

(On a 4-core 2.6GHz Intel Core i7, training the 3 main models for these 20 passes takes about an hour; the per-pass evaluations add considerably more, and the full run logged below took roughly 4.5 hours.)
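The learning-rate schedule in the training loop is a linear decay: each pass trains at one fixed alpha (min_alpha is set equal to it), then alpha drops by (alpha - min_alpha) / passes = 0.0012. A standalone sketch of just that schedule, using the notebook's values:

```python
alpha, min_alpha, passes = 0.025, 0.001, 20
alpha_delta = (alpha - min_alpha) / passes  # 0.0012 per pass

schedule = []
for epoch in range(passes):
    schedule.append(alpha)  # the whole pass trains at this single alpha
    alpha -= alpha_delta

print(round(schedule[0], 6), round(schedule[1], 6), round(schedule[-1], 6))  # 0.025 0.0238 0.0022
```

These values match the "completed pass N at alpha ..." lines in the log. Note that the last pass trains at 0.0022: alpha only reaches min_alpha=0.001 after the final decrement, when no training remains.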


In [7]:
from collections import defaultdict
best_error = defaultdict(lambda: 1.0)  # to selectively print only the best errors achieved

In [8]:
from random import shuffle
import datetime

alpha, min_alpha, passes = (0.025, 0.001, 20)
alpha_delta = (alpha - min_alpha) / passes

print("START %s" % datetime.datetime.now())

for epoch in range(passes):
    shuffle(doc_list)  # shuffling gets best results
    
    for name, train_model in models_by_name.items():
        # train
        duration = 'na'
        train_model.alpha, train_model.min_alpha = alpha, alpha
        with elapsed_timer() as elapsed:
            train_model.train(doc_list, total_examples=len(doc_list), epochs=1)
            duration = '%.1f' % elapsed()
            
        # evaluate
        eval_duration = ''
        with elapsed_timer() as eval_elapsed:
            err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs)
        eval_duration = '%.1f' % eval_elapsed()
        best_indicator = ' '
        if err <= best_error[name]:
            best_error[name] = err
            best_indicator = '*' 
        print("%s%f : %i passes : %s %ss %ss" % (best_indicator, err, epoch + 1, name, duration, eval_duration))

        if ((epoch + 1) % 5) == 0 or epoch == 0:
            eval_duration = ''
            with elapsed_timer() as eval_elapsed:
                infer_err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs, infer=True)
            eval_duration = '%.1f' % eval_elapsed()
            best_indicator = ' '
            if infer_err < best_error[name + '_inferred']:
                best_error[name + '_inferred'] = infer_err
                best_indicator = '*'
            print("%s%f : %i passes : %s %ss %ss" % (best_indicator, infer_err, epoch + 1, name + '_inferred', duration, eval_duration))

    print('completed pass %i at alpha %f' % (epoch + 1, alpha))
    alpha -= alpha_delta
    
print("END %s" % str(datetime.datetime.now()))


START 2017-06-06 15:19:50.208091
*0.408320 : 1 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 131.9s 33.6s
*0.341600 : 1 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 131.9s 48.3s
*0.239960 : 1 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 35.3s 45.9s
*0.193200 : 1 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 35.3s 48.3s
*0.268640 : 1 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 48.6s 48.5s
*0.208000 : 1 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 48.6s 47.4s
*0.216160 : 1 passes : dbow+dmm 0.0s 168.9s
*0.176000 : 1 passes : dbow+dmm_inferred 0.0s 176.4s
*0.237280 : 1 passes : dbow+dmc 0.0s 169.3s
*0.194400 : 1 passes : dbow+dmc_inferred 0.0s 183.9s
completed pass 1 at alpha 0.025000
*0.346760 : 2 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 133.4s 42.2s
*0.145280 : 2 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 29.0s 42.8s
*0.210920 : 2 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 38.8s 42.2s
*0.139120 : 2 passes : dbow+dmm 0.0s 173.2s
*0.147120 : 2 passes : dbow+dmc 0.0s 191.8s
completed pass 2 at alpha 0.023800
*0.314920 : 3 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 112.3s 37.6s
*0.126720 : 3 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 28.4s 42.6s
*0.191920 : 3 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 37.9s 42.2s
*0.121640 : 3 passes : dbow+dmm 0.0s 190.8s
*0.127040 : 3 passes : dbow+dmc 0.0s 188.1s
completed pass 3 at alpha 0.022600
*0.282080 : 4 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 104.9s 36.3s
*0.115520 : 4 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 27.6s 49.9s
*0.181280 : 4 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 40.7s 42.2s
*0.114760 : 4 passes : dbow+dmm 0.0s 188.6s
*0.116040 : 4 passes : dbow+dmc 0.0s 192.5s
completed pass 4 at alpha 0.021400
*0.257560 : 5 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 102.5s 35.8s
*0.265200 : 5 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 102.5s 48.6s
*0.110880 : 5 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 27.0s 46.5s
*0.117600 : 5 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 27.0s 50.5s
*0.171240 : 5 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 39.1s 43.7s
*0.207200 : 5 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 39.1s 47.5s
*0.108920 : 5 passes : dbow+dmm 0.0s 203.4s
*0.114800 : 5 passes : dbow+dmm_inferred 0.0s 213.4s
*0.111520 : 5 passes : dbow+dmc 0.0s 189.5s
*0.132000 : 5 passes : dbow+dmc_inferred 0.0s 202.6s
completed pass 5 at alpha 0.020200
*0.240440 : 6 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 117.6s 39.2s
*0.107600 : 6 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 32.3s 52.1s
*0.166800 : 6 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 46.4s 40.8s
*0.108160 : 6 passes : dbow+dmm 0.0s 197.8s
*0.109920 : 6 passes : dbow+dmc 0.0s 189.4s
completed pass 6 at alpha 0.019000
*0.225280 : 7 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 102.8s 36.0s
*0.105560 : 7 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 31.0s 47.0s
*0.164320 : 7 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 38.6s 43.7s
*0.104760 : 7 passes : dbow+dmm 0.0s 187.1s
*0.107600 : 7 passes : dbow+dmc 0.0s 182.9s
completed pass 7 at alpha 0.017800
*0.214280 : 8 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 99.2s 41.1s
*0.102400 : 8 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 28.6s 47.3s
*0.161000 : 8 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 36.4s 40.9s
*0.102720 : 8 passes : dbow+dmm 0.0s 188.2s
*0.104280 : 8 passes : dbow+dmc 0.0s 187.3s
completed pass 8 at alpha 0.016600
*0.206840 : 9 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 96.9s 41.4s
 0.102920 : 9 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 27.1s 46.4s
*0.158600 : 9 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 40.3s 40.7s
*0.101880 : 9 passes : dbow+dmm 0.0s 188.1s
*0.103960 : 9 passes : dbow+dmc 0.0s 192.2s
completed pass 9 at alpha 0.015400
*0.198960 : 10 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 116.0s 43.0s
*0.194000 : 10 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 116.0s 54.2s
*0.102120 : 10 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 27.8s 47.1s
*0.100000 : 10 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 27.8s 50.4s
*0.156640 : 10 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 38.3s 41.9s
*0.178400 : 10 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 38.3s 46.8s
 0.102520 : 10 passes : dbow+dmm 0.0s 192.5s
*0.104000 : 10 passes : dbow+dmm_inferred 0.0s 207.3s
*0.103560 : 10 passes : dbow+dmc 0.0s 191.0s
*0.115200 : 10 passes : dbow+dmc_inferred 0.0s 203.5s
completed pass 10 at alpha 0.014200
*0.192000 : 11 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 97.3s 42.7s
 0.102840 : 11 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.8s 45.1s
 0.156680 : 11 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 36.9s 41.1s
*0.101600 : 11 passes : dbow+dmm 0.0s 187.8s
 0.103880 : 11 passes : dbow+dmc 0.0s 187.9s
completed pass 11 at alpha 0.013000
*0.190440 : 12 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 99.1s 44.5s
 0.103640 : 12 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 34.7s 45.9s
*0.154640 : 12 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 37.3s 41.8s
 0.103400 : 12 passes : dbow+dmm 0.0s 190.1s
 0.103640 : 12 passes : dbow+dmc 0.0s 190.6s
completed pass 12 at alpha 0.011800
*0.186840 : 13 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 99.1s 41.0s
 0.102560 : 13 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.7s 44.5s
*0.153880 : 13 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 35.9s 40.0s
 0.103760 : 13 passes : dbow+dmm 0.0s 182.8s
 0.103680 : 13 passes : dbow+dmc 0.0s 174.8s
completed pass 13 at alpha 0.010600
*0.184600 : 14 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 92.0s 38.6s
 0.103080 : 14 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.7s 44.5s
*0.153760 : 14 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 35.8s 39.0s
 0.103120 : 14 passes : dbow+dmm 0.0s 177.6s
 0.103960 : 14 passes : dbow+dmc 0.0s 176.0s
completed pass 14 at alpha 0.009400
*0.182720 : 15 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 91.7s 38.7s
*0.179600 : 15 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 91.7s 50.8s
 0.103280 : 15 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.7s 43.5s
 0.104400 : 15 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 26.7s 47.8s
*0.153720 : 15 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 36.0s 39.0s
 0.187200 : 15 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 36.0s 43.7s
 0.103520 : 15 passes : dbow+dmm 0.0s 174.9s
 0.105600 : 15 passes : dbow+dmm_inferred 0.0s 183.2s
 0.103680 : 15 passes : dbow+dmc 0.0s 175.9s
*0.106000 : 15 passes : dbow+dmc_inferred 0.0s 189.9s
completed pass 15 at alpha 0.008200
*0.181040 : 16 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 91.6s 41.2s
 0.103240 : 16 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.7s 45.3s
*0.153600 : 16 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 36.1s 40.6s
 0.103960 : 16 passes : dbow+dmm 0.0s 175.9s
*0.103400 : 16 passes : dbow+dmc 0.0s 175.9s
completed pass 16 at alpha 0.007000
*0.180080 : 17 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 92.1s 40.3s
 0.102760 : 17 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.3s 44.9s
*0.152880 : 17 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 35.4s 39.0s
 0.103200 : 17 passes : dbow+dmm 0.0s 182.5s
*0.103280 : 17 passes : dbow+dmc 0.0s 178.0s
completed pass 17 at alpha 0.005800
*0.178720 : 18 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 91.1s 39.0s
*0.101640 : 18 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.4s 44.3s
*0.152280 : 18 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 35.6s 39.5s
 0.102360 : 18 passes : dbow+dmm 0.0s 183.8s
 0.103320 : 18 passes : dbow+dmc 0.0s 179.0s
completed pass 18 at alpha 0.004600
*0.178600 : 19 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 91.1s 38.9s
 0.102320 : 19 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.3s 45.7s
*0.151920 : 19 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 35.5s 40.7s
 0.102240 : 19 passes : dbow+dmm 0.0s 181.7s
*0.103000 : 19 passes : dbow+dmc 0.0s 181.7s
completed pass 19 at alpha 0.003400
*0.177360 : 20 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4) 90.9s 40.0s
 0.190800 : 20 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred 90.9s 52.1s
 0.102520 : 20 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4) 26.4s 45.2s
 0.108800 : 20 passes : Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred 26.4s 48.7s
*0.151680 : 20 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4) 35.5s 40.8s
 0.182400 : 20 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred 35.5s 45.3s
 0.102320 : 20 passes : dbow+dmm 0.0s 183.5s
 0.113200 : 20 passes : dbow+dmm_inferred 0.0s 192.3s
*0.102800 : 20 passes : dbow+dmc 0.0s 183.3s
 0.111200 : 20 passes : dbow+dmc_inferred 0.0s 196.1s
completed pass 20 at alpha 0.002200
END 2017-06-06 19:46:10.508929

Achieved Sentiment-Prediction Accuracy


In [9]:
# print best error rates achieved
for rate, name in sorted((rate, name) for name, rate in best_error.items()):
    print("%f %s" % (rate, name))


0.100000 Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)_inferred
0.101600 dbow+dmm
0.101640 Doc2Vec(dbow,d100,n5,mc2,s0.001,t4)
0.102800 dbow+dmc
0.104000 dbow+dmm_inferred
0.106000 dbow+dmc_inferred
0.151680 Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)
0.177360 Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)
0.178400 Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4)_inferred
0.179600 Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4)_inferred

In my testing, unlike the paper's report, DBOW performs best. Concatenating vectors from different models offers only a small predictive improvement, if any. The best results I've seen are just under a 10% error rate, still some way from the paper's 7.42%.
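When comparing these rates, keep the sample sizes in mind: the plain rates are measured on all 25,000 TEST docs, while the _inferred rates use only a random 10% sample (2,500 docs, per infer_subsample=0.1 in error_rate_for_model). Treating an error rate p on n docs as a binomial proportion, its standard error is sqrt(p(1 - p) / n). A quick sketch (binomial_se is a hypothetical helper name):

```python
import math

def binomial_se(p, n):
    """Standard error of an error rate p measured on n independent test docs."""
    return math.sqrt(p * (1 - p) / n)

print(round(binomial_se(0.10, 25000), 4))  # 0.0019
print(round(binomial_se(0.10, 2500), 4))   # 0.006
```

For example, the 0.100000 best for DBOW _inferred and the 0.101640 for the stored DBOW vectors differ by about 0.0016, well under the roughly 0.006 standard error of a rate measured on 2,500 docs, so that gap may be sampling noise rather than a real advantage of inference.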

Examining Results

Are inferred vectors close to the precalculated ones?


In [10]:
doc_id = np.random.randint(simple_models[0].docvecs.count)  # pick random doc; re-run cell for more examples
print('for doc %d...' % doc_id)
for model in simple_models:
    inferred_docvec = model.infer_vector(alldocs[doc_id].words)
    print('%s:\n %s' % (model, model.docvecs.most_similar([inferred_docvec], topn=3)))


for doc 47495...
Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4):
 [(47495, 0.8063223361968994), (28683, 0.4661555588245392), (10030, 0.3962923586368561)]
Doc2Vec(dbow,d100,n5,mc2,s0.001,t4):
 [(47495, 0.9660482406616211), (17469, 0.5925078392028809), (52349, 0.5742233991622925)]
Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4):
 [(47495, 0.8801028728485107), (60782, 0.5431949496269226), (42472, 0.5375599265098572)]

(Yes, here the stored vector from 20 epochs of training is usually one of the closest to a freshly-inferred vector for the same words. Note the defaults for inference are very abbreviated – just 3 steps starting at a high alpha – and likely need tuning for other applications.)
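most_similar ranks stored vectors by cosine similarity to the query vector. A pure-Python sketch of that ranking over a toy set of 2-d vectors (cosine and top_n are hypothetical helper names; gensim does the same comparison with vectorized numpy operations):

```python
import math

def cosine(a, b):
    """Cosine similarity: dot product divided by the product of the vector lengths."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def top_n(query, vectors, n):
    """Return (tag, similarity) pairs for the n stored vectors closest to query."""
    sims = [(tag, cosine(query, vec)) for tag, vec in vectors.items()]
    sims.sort(key=lambda pair: pair[1], reverse=True)
    return sims[:n]

stored = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.7, 0.7], 3: [-1.0, 0.0]}
inferred = [0.9, 0.1]  # e.g. a freshly inferred vector that lands near doc 0
print([tag for tag, _ in top_n(inferred, stored, 3)])  # [0, 2, 1]
```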


In [11]:
import random

doc_id = np.random.randint(simple_models[0].docvecs.count)  # pick random doc, re-run cell for more examples
model = random.choice(simple_models)  # and a random model
sims = model.docvecs.most_similar(doc_id, topn=model.docvecs.count)  # get *all* similar documents
print(u'TARGET (%d): «%s»\n' % (doc_id, ' '.join(alldocs[doc_id].words)))
print(u'SIMILAR/DISSIMILAR DOCS PER MODEL %s:\n' % model)
for label, index in [('MOST', 0), ('MEDIAN', len(sims)//2), ('LEAST', len(sims) - 1)]:
    print(u'%s %s: «%s»\n' % (label, sims[index], ' '.join(alldocs[sims[index][0]].words)))


TARGET (43375): «the film " chaos " takes its name from gleick's 1988 pop science explanation of chaos theory . what does the book or anything related to the content of the book have to do with the plot of the movie " chaos " ? nothing . the film makers seem to have skimmed the book ( obviously without understanding a thing about it ) looking for a " theme " to united the series of mundane action sequences that overlie the flimsy string of events that acts in place of a plot in the film . in this respect , the movie " choas " resembles the canadian effort " cube , " in which prime numbers function as a device to mystify the audience so that the ridiculousness of the plot will not be noticed : in " cube " a bunch of prime numbers are tossed in so that viewers will attribute their lack of understanding to lack of knowledge about primes : the same approach is taken in " chaos " : disconnected extracts from gleick's books are thrown in make the doings of the bad guy in the film seem fiendishly clever . this , of course , is an insultingly condescending treatment of the audience , and any literate viewer of " chaos " who can stand to sit through the entire film will end up bewildered . how could a film so bad be made ? rewritten as a novel , the story in " chaos " would probably not even make it past a literary agent's secretary's desk . how could ( at least ) hundreds of thousands ( and probably millions ) of dollars have been thrown away on what can only be considered a waste of time for everyone except those who took home money from the film ? regarding what's in the movie , every performance is phoned in . save for technical glitches , it would be astonishing if more than one take was used for any one scene . the story is uniformly senseless : the last time i saw a story to disconnected it was the production of a literal eight-year-old . among other massive shortcomings are the following : the bad guy leaves hints for the police to follow . 
he has no reason whatsoever for leaving such hints . police officers do not carry or use radios . dupes of the bad guy have no reason to act in concert with the bad guy . let me strongly recommend that no one watch this film . if there is any other movie you like ( or even simply do not hate ) watch that instead .»

SIMILAR/DISSIMILAR DOCS PER MODEL Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4):

MOST (48890, 0.5806792378425598): «asmali konak has arguably become one of the best tv series to come out of turkey . with its unique cinematography and visual approach to filming , the series has gained a wide following base with rating records continuously broken . personally i do not agree with singers becoming actors ( hence , ozcan deniz - the lead actor ) but i guess the figures speak for themselves . in relation to the movie , it was disgusting to see how much someone can destroy such a plotline . years in the making , this movie was able to oversee every descent story that existed within the series . not only that , the cultural mistakes were unacceptable , with an idiotic scene involving the family members dancing ( greek style ) and breaking plates , which does not exists anywhere within the turkish culture . some argue the movie should be taken as a stand alone movie not as a continuation of the tv series but this theory has one major fall , the way the movie was marketed was that it will be picking up where the series left off and will conclude the series once and for all . so with that note in mind , me and everyone i know , would have asked for a refund and accepted to stand outside the theatre to warn other victims .»

MEDIAN (93452, 0.22335509955883026): «this is the second film ( dead men walking ) set in a prison by theasylum . the mythos behind plot is very good , russian mafia has this demon do there dirty work and the rainbow array of inmates have to defend their bars & mortar . jennifer lee ( see interview ) wiggins stars as a prison guard who has a inmate , who maybe a demon . the monster suit is awesome and frightening , and a different look that almost smacks of a toy franchise , hey if full moon and todd mcfarlane can make action figures for any character . . why not the beast from bray road wolfette , shapeshifter with medallion accessory , or the rhett giles everyman hero with removable appendages .»

LEAST (57989, -0.22353392839431763): «saw this movie on re-run just once , when i was about 13 , in 1980 . it completely matched my teenaged fantasies of sweet , gentle , interesting — and let's face it — hot — " older " guys . just ordered it from cd universe about a month ago , and have given it about four whirls in the two weeks since . as somebody mentioned — i'm haunted by it . as somebody else mentioned — i think it's part of a midlife crisis as well ! being 39 and realizing how much has changed since those simpler '70s times when girls of 13 actually did take buses and go to malls together and had a lot more freedom away from the confines of modern suburbia makes me sad for my daughter — who is nearly 13 herself . thirteen back then was in many ways a lot more grown up . the film is definitely '70s but not in a super-dated cheesy way , in fact the outfits denise miller as jessie wears could be current now ! you know what they say , everything that goes around . . . although the short-short jogging shorts worn by rex with the to-the-knees sweat socks probably won't make a comeback . the subject matter is handled in a very sensitive way and the characters are treated with a lot of respect . it's not the most chatty movie going — i often wished for more to be said between jessie and michael that would cement why he was also attracted to her . but the acting is solid , the movie is sweet and atmospheric , and the fringe characters give great performances . mary beth manning as jessie's friend caroline is a total hoot — i think we all had friends like her . maia danziger as the relentless flirt with michael gives a wiggy , stoned-out performance that just makes you laugh — because we also all knew girls that acted like that . denise miller knocked her performance out of the ballpark with a very down-to-earth quality likely credited to her uknown status and being new to the industry . 
and i think not a little of the credit for the film's theatre-grade quality comes from the very capable , brilliant hands of the story's authors , carole and the late bruce hart , who also wrote for sesame street . they really cared about the message of the movie , which was not an overt in-your-face thing , while at the same time understanding how eager many girls are to grow up at that age . one thing that made me love the film then as much as now is not taking the cliché , easy , tied-with-a-bow but sort of let-down ending . in fact it's probably the end that has caused so many women to return to viewing the movie in their later years . re-watching sooner or later has me absolutely sick with nostalgia for those simpler times , and has triggered a ridiculous and sudden obsession with catching up with rex smith — whom while i enjoyed his albums sooner or later and forever when i was young , i never plastered his posters on my walls as i did some of my other faves . in the past week , i've put his music on my ipod , read fan sites , found interviews ( and marveled in just how brilliant he really is — the man has a fascinating way of thinking ) , watched clips on youtube — what am i , 13 ? i guess that's the biggest appeal of this movie . remembering what it was like to be 13 and the whole world was ahead of you .»

(Do close documents seem more related than distant ones? Somewhat: in terms of reviewer tone, movie genre, etc., the MOST cosine-similar docs usually seem more like the TARGET than the MEDIAN or LEAST ones do.)
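The MOST/MEDIAN/LEAST examples are just fixed positions in the list returned by most_similar(..., topn=model.docvecs.count), which is sorted by descending cosine similarity. A sketch with a made-up, already-sorted list (pick_most_median_least is a hypothetical name):

```python
def pick_most_median_least(sims):
    """Given (tag, similarity) pairs sorted from most to least similar,
    return the entries labeled MOST, MEDIAN and LEAST."""
    return {
        'MOST': sims[0],
        'MEDIAN': sims[len(sims) // 2],
        'LEAST': sims[len(sims) - 1],
    }

# Invented similarity list, sorted in descending order
sims = [(10, 0.58), (11, 0.41), (12, 0.22), (13, 0.05), (14, -0.22)]
picked = pick_most_median_least(sims)
print(picked['MEDIAN'])  # (12, 0.22)
```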

Do the word vectors show useful similarities?


In [12]:
word_models = simple_models[:]

In [13]:
import random
from IPython.display import HTML
# pick a random word with a suitable number of occurrences
while True:
    word = random.choice(word_models[0].wv.index2word)
    if word_models[0].wv.vocab[word].count > 10:
        break
# or uncomment below line, to just pick a word from the relevant domain:
#word = 'comedy/drama'
similars_per_model = [str(model.wv.most_similar(word, topn=20)).replace('), ','),<br>\n') for model in word_models]
similar_table = ("<table><tr><th>" +
    "</th><th>".join([str(model) for model in word_models]) + 
    "</th></tr><tr><td>" +
    "</td><td>".join(similars_per_model) +
    "</td></tr></table>")
print("most similar words for '%s' (%d occurences)" % (word, simple_models[0].wv.vocab[word].count))
HTML(similar_table)


most similar words for 'gymnast' (36 occurences)
Out[13]:
Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4):
[('scientist', 0.530441164970398),
('psychotherapist', 0.527083694934845),
('parapsychologist', 0.5239906907081604),
('cringer', 0.5199892520904541),
('samir', 0.5048707127571106),
('reporter', 0.49532145261764526),
('swimmer', 0.4937909245491028),
('thrill-seeker', 0.4905340373516083),
('chiara', 0.48281964659690857),
('psychiatrist', 0.4788440763950348),
('nerd', 0.4779984951019287),
('surgeon', 0.47712844610214233),
('jock', 0.4741038382053375),
('geek', 0.4714686870574951),
('mumu', 0.47104766964912415),
('painter', 0.4689804017543793),
('cheater', 0.4655175805091858),
('hypnotist', 0.4645438492298126),
('whizz', 0.46407681703567505),
('cryptozoologist', 0.4627385437488556)]

Doc2Vec(dbow,d100,n5,mc2,s0.001,t4):
[('bang-bang', 0.4289792478084564),
('master', 0.41190674901008606),
('greenleaf', 0.38207903504371643),
('122', 0.3811250925064087),
('fingernails', 0.3794997036457062),
('cardboard-cutout', 0.3740081787109375),
("album'", 0.3706256151199341),
('sex-starved', 0.3696949779987335),
('creme-de-la-creme', 0.36426788568496704),
('destroyed', 0.3638569116592407),
('imminent', 0.3612757921218872),
('cruisers', 0.3568859398365021),
("emo's", 0.35605981945991516),
('lavransdatter', 0.3534432649612427),
("'video'", 0.3508487641811371),
('garris', 0.3507363796234131),
('romanzo', 0.3495352268218994),
('tombes', 0.3494585454463959),
('story-writers', 0.3461073637008667),
('georgette', 0.34602558612823486)]

Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4):
[('ex-marine', 0.5273298621177673),
('koichi', 0.5020822882652283),
('dorkish', 0.49750325083732605),
('fenyö', 0.4765225946903229),
('castleville', 0.46756264567375183),
('smoorenburg', 0.46484801173210144),
('chimp', 0.46456438302993774),
('swimmer', 0.46236276626586914),
('falcone', 0.4614230990409851),
('yak', 0.45991501212120056),
('gms', 0.4542686939239502),
('iván', 0.4503802955150604),
('spidy', 0.4494086503982544),
('arnie', 0.44659116864204407),
('hobo', 0.4465593695640564),
('evelyne', 0.4455353617668152),
('pandey', 0.4452363848686218),
('hector', 0.4442984461784363),
('baboon', 0.44382452964782715),
('miao', 0.4437481164932251)]
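The `while True` loop in the cell above is rejection sampling: it keeps drawing random words until one occurs more than 10 times, and it would never terminate if no word qualified. A minimal alternative sketch filters the candidates first and then draws once. `pick_random_word` and the toy counts are hypothetical; in the notebook the counts would come from `word_models[0].wv.vocab[word].count`.

```python
import random

def pick_random_word(word_counts, min_count=10, rng=random):
    """Pick a random word occurring more than min_count times.

    Builds the candidate list up front, so it fails fast with an error
    instead of looping forever when no word qualifies.
    """
    candidates = [w for w, c in word_counts.items() if c > min_count]
    if not candidates:
        raise ValueError('no word occurs more than %d times' % min_count)
    return rng.choice(candidates)

# Hypothetical counts standing in for a model's vocabulary
counts = {'gymnast': 36, 'plot': 5000, 'zyzzyva': 1}
print(pick_random_word(counts))
```

The trade-off: filtering costs one full pass over the vocabulary per call, while rejection sampling is usually fast when most words pass the threshold but has no upper bound on its number of tries.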

Do the DBOW words look meaningless? That's because the gensim DBOW model doesn't train word vectors: they remain at their randomly initialized values unless you request word training with the dbow_words=1 initialization parameter. Concurrent word training slows DBOW mode significantly, and offers little improvement (and sometimes a slight worsening) in the error rate on this IMDB sentiment-prediction task.
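Why do untrained vectors still produce neighbor lists with similarities of roughly 0.35 to 0.43, as in the DBOW results above? In 100 dimensions, the cosine between two independent random vectors is centered on 0 with a spread of about 1/sqrt(100) = 0.1, so the best few matches among thousands of random vectors land well above 0.3 purely by chance. The sketch below illustrates this under assumed settings: a small, zero-centered uniform initialization roughly like gensim's, and an arbitrary toy vocabulary of 5,000 words.

```python
import math
import random

def random_vector(dim, rng):
    # Small uniform values centered on zero (assumed, roughly like gensim's init)
    return [(rng.random() - 0.5) / dim for _ in range(dim)]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

rng = random.Random(42)
dim, vocab_size = 100, 5000
query = random_vector(dim, rng)
# Similarity of the query to every "word", all of them untrained random vectors
sims = sorted((cosine(query, random_vector(dim, rng)) for _ in range(vocab_size)),
              reverse=True)
print('top-5 similarities among untrained vectors:', [round(s, 3) for s in sims[:5]])
print('mean similarity: %.4f' % (sum(sims) / len(sims)))
```

The larger the vocabulary, the higher these chance maxima climb, so a high-looking similarity score on its own is not evidence that a word vector was trained.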

Word vectors from the DM models tend to show meaningfully similar words when there are many examples in the training data (as with 'plot' or 'actor'). (All DM modes inherently train word vectors concurrently with doc vectors.)

Are the word vectors from this dataset any good at analogies?


In [15]:
# assumes questions-words.txt, e.g. from
# https://word2vec.googlecode.com/svn/trunk/questions-words.txt
# is in the local directory
# note: this takes many minutes
for model in word_models:
    sections = model.accuracy('questions-words.txt')
    # the last section holds the totals across all question categories
    correct, incorrect = len(sections[-1]['correct']), len(sections[-1]['incorrect'])
    print('%s: %0.2f%% correct (%d of %d)' % (model, float(correct*100)/(correct+incorrect), correct, correct+incorrect))


Doc2Vec(dm/c,d100,n5,w5,mc2,s0.001,t4): 31.50% correct (3154 of 10012)
Doc2Vec(dbow,d100,n5,mc2,s0.001,t4): 0.00% correct (0 of 10012)
Doc2Vec(dm/m,d100,n5,w10,mc2,s0.001,t4): 32.24% correct (3228 of 10012)

Even though this is a tiny, domain-specific dataset, it shows some meager capability on the general word analogies, at least for the DM/concat and DM/mean models, which actually train word vectors. (The untrained, randomly initialized word vectors of the DBOW model, of course, fail miserably.)
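Each line of questions-words.txt is an analogy "a is to b as c is to d". A common way to score it is 3CosAdd: unit-normalize the vectors, find the word nearest to b - a + c (excluding a, b and c themselves), and count the question correct if that word is d. Questions with words missing from the vocabulary are skipped, which is one reason a model can be tested on fewer questions than the file contains. Below is a minimal sketch of that scoring rule on hypothetical 2-dimensional toy vectors. gensim's own `accuracy()` may add details, such as restricting the search to frequent words, that are omitted here.

```python
import math

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def solve_analogy(unit, a, b, c):
    """Answer 'a is to b as c is to ?' via 3CosAdd on unit-length vectors."""
    query = normalize([vb - va + vc for va, vb, vc in zip(unit[a], unit[b], unit[c])])
    candidates = (w for w in unit if w not in (a, b, c))  # exclude question words
    return max(candidates, key=lambda w: sum(q * x for q, x in zip(query, unit[w])))

def analogy_accuracy(vectors, questions):
    """Return (correct, total) over questions whose four words are all in vocab."""
    unit = {w: normalize(v) for w, v in vectors.items()}
    usable = [q for q in questions if all(w in unit for w in q)]
    correct = sum(1 for a, b, c, d in usable if solve_analogy(unit, a, b, c) == d)
    return correct, len(usable)

# Hypothetical toy vectors: axis 0 ~ "royalty", axis 1 ~ "gender"
toy = {
    'man': [0.0, 1.0], 'woman': [0.0, -1.0],
    'king': [1.0, 1.0], 'queen': [1.0, -1.0],
    'boy': [-0.2, 1.0], 'girl': [-0.2, -1.0],
}
questions = [
    ('man', 'woman', 'king', 'queen'),
    ('man', 'king', 'woman', 'queen'),
    ('man', 'woman', 'boy', 'girl'),
    ('man', 'woman', 'prince', 'princess'),  # skipped: not in vocabulary
]
correct, total = analogy_accuracy(toy, questions)
print('%0.2f%% correct (%d of %d)' % (correct * 100.0 / total, correct, total))
# 100.00% correct (3 of 3)
```

With untrained random vectors, the word nearest to b - a + c is essentially arbitrary, which is consistent with the DBOW model's 0% score above.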

Slop


In [ ]:
This cell left intentionally erroneous.

To mix the pre-trained Google News word2vec vectors (if locally available) into the word tests...


In [ ]:
from gensim.models import KeyedVectors
# pre-trained 300-dimensional vectors for ~3 million words and phrases; needs several GB of RAM
w2v_g100b = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)
w2v_g100b.compact_name = 'w2v_g100b'
word_models.append(w2v_g100b)

To get copious logging output from above steps...


In [ ]:
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
rootLogger = logging.getLogger()
rootLogger.setLevel(logging.INFO)

To auto-reload Python code while developing...


In [ ]:
%load_ext autoreload
%autoreload 2