Deep Learning

Assignment 5

The goal of this assignment is to train a Word2Vec skip-gram model over Text8 data.


In [1]:
# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
from __future__ import print_function

import os
import math
import random
import zipfile
import collections

import numpy as np
import seaborn as sns
import tensorflow as tf
from six.moves import range
from six.moves.urllib.request import urlretrieve
from matplotlib import pylab
from sklearn.manifold import TSNE

%matplotlib inline

Download the data from the source website if necessary.


In [2]:
url = 'http://mattmahoney.net/dc/'

def maybe_download(filename, expected_bytes):
    """Download a file if not present, and make sure it's the right size."""
    if not os.path.exists(filename):
        filename, _ = urlretrieve(url + filename, filename)
    statinfo = os.stat(filename)
    if statinfo.st_size == expected_bytes:
        print('Found and verified %s' % filename)
    else:
        print(statinfo.st_size)
        raise Exception('Failed to verify ' + filename + '. Can you get to it with a browser?')
    return filename

filename = maybe_download('text8.zip', 31344016)


Found and verified text8.zip

Read the data into a list of words.


In [3]:
def read_data(filename):
    """Extract the first file enclosed in a zip file as a list of words"""
    with zipfile.ZipFile(filename) as f:
        data = tf.compat.as_str(f.read(f.namelist()[0])).split()
    return data
  
words = read_data(filename)
print('Data size %d' % len(words))


Data size 17005207

Build the dictionary and replace rare words with the UNK token.


In [4]:
vocabulary_size = 50000

def build_dataset(words):
    # count the (vocabulary_size - 1) most common words in the corpus
    count = [['UNK', -1]]
    count.extend(collections.Counter(words).most_common(vocabulary_size - 1))

    # assign an integer id to each word in the vocabulary
    dictionary = dict()
    for word, _ in count:
        dictionary[word] = len(dictionary)
    
    # encode the corpus as a sequence of word ids
    data = list()
    unk_count = 0
    for word in words:
        # use the assigned id if the word is in the vocabulary
        if word in dictionary:
            index = dictionary[word]
        # otherwise map it to UNK (id 0)
        else:
            index = 0  # dictionary['UNK']
            unk_count = unk_count + 1
        data.append(index)

    # update unknown words count
    count[0][1] = unk_count
    reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 
    return data, count, dictionary, reverse_dictionary

# build the dataset: encoded corpus, word counts, and dictionaries
data, count, dictionary, reverse_dictionary = build_dataset(words)
print('Most common words (+UNK)', count[:5])
print('Sample data', data[:10])
del words  # Hint to reduce memory.


Most common words (+UNK) [['UNK', 418391], ('the', 1061396), ('of', 593677), ('and', 416629), ('one', 411764)]
Sample data [5239, 3084, 12, 6, 195, 2, 3137, 46, 59, 156]

Function to generate a training batch for the skip-gram model.


In [5]:
data_index = 0

def generate_batch(batch_size, num_skips, skip_window):
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
  
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    
    # [ skip_window | center | skip_window ]
    span = 2 * skip_window + 1
    # sliding buffer holding the current window of words
    buffer = collections.deque(maxlen=span)

    # read enough words to fill the span
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    # generate the (center word, context word) skip-gram pairs
    for i in range(batch_size // num_skips):
        # the center word sits at index skip_window of the buffer
        target = skip_window
        # context indices already used (start with the center itself)
        targets_to_avoid = [skip_window]

        # sample num_skips distinct context words for this center word
        for j in range(num_skips):
            while target in targets_to_avoid:
                target = random.randint(0, span - 1)
            targets_to_avoid.append(target)
            labels[i * num_skips + j, 0] = buffer[target]
            batch[i * num_skips + j] = buffer[skip_window]

        # slide the window forward by one word
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    return batch, labels

print('data:', [reverse_dictionary[di] for di in data[:8]])
for num_skips, skip_window in [(2, 1), (4, 2)]:
    data_index = 0
    batch, labels = generate_batch(batch_size=8, num_skips=num_skips, skip_window=skip_window)
    print('\nwith num_skips = %d and skip_window = %d:' % (num_skips, skip_window))
    print('    batch:', [reverse_dictionary[bi] for bi in batch])
    print('    labels:', [reverse_dictionary[li] for li in labels.reshape(8)])


data: ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first']

with num_skips = 2 and skip_window = 1:
    batch: ['originated', 'originated', 'as', 'as', 'a', 'a', 'term', 'term']
    labels: ['anarchism', 'as', 'originated', 'a', 'term', 'as', 'a', 'of']

with num_skips = 4 and skip_window = 2:
    batch: ['as', 'as', 'as', 'as', 'a', 'a', 'a', 'a']
    labels: ['term', 'anarchism', 'originated', 'a', 'term', 'of', 'as', 'originated']

Train a skip-gram model.


In [6]:
# How many times to reuse an input to generate a label.
num_skips = 2 
# How many words to consider left and right.
skip_window = 1 
# batch size for training
batch_size = 128
# Dimension of the embedding vector.
embedding_size = 128

# We pick a random validation set to sample nearest neighbors. Here we limit
# the validation samples to words that have a low numeric ID, which by
# construction are also the most frequent.

# Random set of words to evaluate similarity on
valid_size = 16 
# Only pick dev samples in the head of the distribution.
valid_window = 100 
valid_examples = np.array(random.sample(range(valid_window), valid_size))

# Number of negative examples to sample.
num_sampled = 64 

# generate computation graph
graph = tf.Graph()
with graph.as_default(), tf.device('/cpu:0'):

    # Input data.
    train_dataset = tf.placeholder(tf.int32, shape=[batch_size])
    train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
    valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
  
    # Variables.
    embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    softmax_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size],
                                                      stddev=1.0 / math.sqrt(embedding_size)))
    softmax_biases = tf.Variable(tf.zeros([vocabulary_size]))
  
    # Model.
    # Look up embeddings for inputs.
    embed = tf.nn.embedding_lookup(embeddings, train_dataset)
    # Compute the softmax loss, using a sample of the negative labels each time.
    # Keyword arguments are used because the positional order of these arguments
    # differs between TensorFlow versions.
    loss = tf.reduce_mean(
        tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases,
                                   inputs=embed, labels=train_labels,
                                   num_sampled=num_sampled, num_classes=vocabulary_size))


    # Optimizer.
    optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)
  
    # Compute the similarity between minibatch examples and all embeddings.
    # We use cosine similarity:
    norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
    normalized_embeddings = embeddings / norm
    
    valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset)
    similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings))
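
The similarity op above is plain cosine similarity: every embedding row is divided by its L2 norm, and a matrix product of unit-length vectors gives pairwise cosine similarities. A small NumPy sanity check of the same computation, using made-up toy vectors:


In [ ]:
# Verify that row-normalizing and taking a matrix product yields pairwise
# cosine similarities (the toy vectors are purely illustrative).
toy = np.array([[1.0, 2.0, 0.5],
                [0.5, 1.5, 2.0]])
norms = np.sqrt(np.sum(np.square(toy), axis=1, keepdims=True))
normalized = toy / norms

# entry [0, 1] of the matrix product equals the direct cosine-similarity formula
sim_matrix = normalized.dot(normalized.T)
direct = toy[0].dot(toy[1]) / (np.linalg.norm(toy[0]) * np.linalg.norm(toy[1]))
print(sim_matrix[0, 1], direct)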

In [7]:
num_steps = 100001

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print('Initialized')
    
    # main loop
    average_loss = 0
    for step in range(num_steps):
        # generate a training batch for this step
        batch_data, batch_labels = generate_batch(batch_size, num_skips, skip_window)
        # feed dictionary for optimizer
        feed_dict = {train_dataset : batch_data, train_labels : batch_labels}
        
        # run optimizer
        _, l = session.run([optimizer, loss], feed_dict=feed_dict)
        average_loss += l
        
        # report the average loss every 2000 steps
        if step % 2000 == 0:
            if step > 0:
                average_loss = average_loss / 2000
            # The average loss is an estimate of the loss over the last 2000 batches.
            print('Average loss at step %d: %f' % (step, average_loss))
            average_loss = 0
    
        # note that this is expensive (~20% slowdown if computed every 500 steps)
        if step % 10000 == 0:
            sim = similarity.eval()            
            for i in range(valid_size):
                valid_word = reverse_dictionary[valid_examples[i]]
                top_k = 8 # number of nearest neighbors
                nearest = (-sim[i,:]).argsort()[1:top_k+1]
                log = 'Nearest to %s:' % valid_word
                for k in range(top_k):
                    close_word = reverse_dictionary[nearest[k]]
                    log = '%s %s,' % (log, close_word)
                print(log)
    
    final_embeddings = normalized_embeddings.eval()


Initialized
Average loss at step 0: 7.833626
Nearest to may: worm, disguised, favre, fasa, uproar, disprove, mellow, andre,
Nearest to years: igor, ucd, inked, downplay, rapport, mpa, guantanamo, dato,
Nearest to one: mardi, crossfire, broth, jafar, waley, specifications, northampton, biography,
Nearest to with: bigcup, landmark, harper, bw, bbl, alleged, sahara, aedui,
Nearest to were: goncourt, motte, anglophones, romantics, abandoning, fordham, slipping, original,
Nearest to time: informants, compensating, simulation, remediation, oaxaca, fistula, deut, rampage,
Nearest to his: differing, suites, racecar, stat, chaired, apostles, caches, deccan,
Nearest to for: molecule, khazars, unmolested, flashing, applets, vero, taoist, oligarchy,
Nearest to was: bonanza, figure, felonies, stretching, singular, cleopatra, cana, indie,
Nearest to i: called, netanyahu, academia, mesocricetus, deprogramming, kino, lenoir, alligators,
Nearest to many: theaters, sandalwood, spamalot, incorporates, camilla, farmland, comfort, donning,
Nearest to while: pennsylvania, gabal, mcvie, denoted, astrophysicist, palpatine, decapitate, wan,
Nearest to this: snipers, tunnels, storage, blueprint, molson, principe, truetype, dalmatia,
Nearest to on: rhymes, eastbound, dacia, chs, aztec, misty, waveforms, conner,
Nearest to only: wilhelmina, davids, rai, ramjet, rewritten, bagdad, excommunicating, byron,
Nearest to these: vtol, fauna, noodle, adenosine, illuminism, dos, flammable, twelve,
Average loss at step 2000: 4.362838
Average loss at step 4000: 3.864067
Average loss at step 6000: 3.788428
Average loss at step 8000: 3.681827
Average loss at step 10000: 3.613047
Nearest to may: can, would, could, disguised, nac, disprove, corse, might,
Nearest to years: igor, overrun, reside, inked, guantanamo, utc, mpa, lessons,
Nearest to one: two, six, three, seven, eight, five, four, nine,
Nearest to with: for, offences, in, by, ferris, duet, and, fail,
Nearest to were: are, have, was, fairest, modus, had, thrusters, abandoning,
Nearest to time: era, oaxaca, madam, currying, confer, admirer, vincenzo, rampage,
Nearest to his: their, its, her, s, the, studi, yastrzemski, tripled,
Nearest to for: with, of, molecule, burnt, juniors, erbium, interpolating, hays,
Nearest to was: is, has, had, were, by, be, are, became,
Nearest to i: netanyahu, rask, kino, ecclesia, parabolic, inspect, inductive, pulmonic,
Nearest to many: some, incorporates, rakis, minority, theaters, sandalwood, ssang, rohypnol,
Nearest to while: pennsylvania, palpatine, gabal, clary, prokaryote, though, ayres, kleene,
Nearest to this: it, the, which, that, khalid, molson, a, any,
Nearest to on: in, by, martingale, chs, aggravating, arica, waveforms, at,
Nearest to only: funakoshi, byron, iblis, woody, chairmanship, bidder, matchbox, being,
Nearest to these: promulgation, vtol, such, reason, international, alley, fauna, signing,
Average loss at step 12000: 3.604215
Average loss at step 14000: 3.570375
Average loss at step 16000: 3.410614
Average loss at step 18000: 3.455217
Average loss at step 20000: 3.541226
Nearest to may: can, would, could, will, should, might, must, corse,
Nearest to years: days, willamette, overrun, gardnerian, utc, reside, flagg, igor,
Nearest to one: two, four, three, seven, five, eight, six, nine,
Nearest to with: in, between, by, for, nmr, examine, nernst, rescued,
Nearest to were: are, was, had, have, by, is, be, been,
Nearest to time: year, era, confer, council, oaxaca, mucus, currying, settled,
Nearest to his: their, its, her, the, tripled, my, this, affiliates,
Nearest to for: einer, with, kelsey, of, fanatic, molecular, int, in,
Nearest to was: is, were, has, had, became, be, are, but,
Nearest to i: netanyahu, kino, we, ii, bashir, rask, cheerfully, stevenson,
Nearest to many: some, several, these, all, other, rakis, various, oneida,
Nearest to while: pennsylvania, but, though, and, palpatine, sniff, prokaryote, became,
Nearest to this: it, which, any, anisotropies, the, another, khalid, maga,
Nearest to on: in, at, broadcasters, upon, during, through, alem, patanjali,
Nearest to only: mimi, funakoshi, renderer, iblis, scourge, jehoahaz, byron, ep,
Nearest to these: such, some, many, they, all, which, vtol, promulgation,
Average loss at step 22000: 3.505056
Average loss at step 24000: 3.489106
Average loss at step 26000: 3.481616
Average loss at step 28000: 3.475496
Average loss at step 30000: 3.502128
Nearest to may: would, can, will, could, must, should, might, damietta,
Nearest to years: days, year, times, willamette, gardnerian, centuries, flagg, overrun,
Nearest to one: two, four, seven, eight, six, five, three, nine,
Nearest to with: between, for, examine, among, in, by, wallach, ferris,
Nearest to were: are, was, have, had, been, by, is, fairest,
Nearest to time: year, approach, admirer, era, way, period, process, synthetically,
Nearest to his: their, her, its, my, the, s, tripled, tries,
Nearest to for: with, of, after, einer, dreaming, towards, burnt, documentary,
Nearest to was: is, were, had, has, became, been, by, being,
Nearest to i: t, ii, netanyahu, we, iii, forever, she, stevenson,
Nearest to many: some, several, these, various, all, both, mechanics, oneida,
Nearest to while: but, though, pennsylvania, is, prokaryote, clary, when, cadet,
Nearest to this: it, which, another, anisotropies, that, any, usually, trypanosomiasis,
Nearest to on: in, upon, through, under, at, rota, asymmetrical, chs,
Nearest to only: renderer, loki, mimi, demand, still, nomination, iblis, chairmanship,
Nearest to these: many, some, such, they, all, both, which, usher,
Average loss at step 32000: 3.502114
Average loss at step 34000: 3.494200
Average loss at step 36000: 3.454283
Average loss at step 38000: 3.299891
Average loss at step 40000: 3.427658
Nearest to may: can, would, will, could, must, should, might, shall,
Nearest to years: days, times, decades, centuries, year, months, willamette, utc,
Nearest to one: seven, four, six, three, eight, five, two, nine,
Nearest to with: between, by, wallach, including, glean, bigcup, tzara, vit,
Nearest to were: are, was, have, had, been, be, those, being,
Nearest to time: year, way, period, day, settled, rohrabacher, lowercase, criticising,
Nearest to his: their, her, its, my, s, the, our, lois,
Nearest to for: when, in, if, of, hays, beach, jesuit, bua,
Nearest to was: is, were, became, had, has, when, been, be,
Nearest to i: t, we, ii, she, they, you, he, netanyahu,
Nearest to many: some, several, these, various, both, those, transliteration, mistakenly,
Nearest to while: and, but, when, though, are, although, vehemently, cadet,
Nearest to this: which, that, it, another, the, anisotropies, dine, any,
Nearest to on: upon, about, at, in, against, during, durch, friar,
Nearest to only: renderer, richer, chef, funakoshi, slippage, it, iblis, supposes,
Nearest to these: many, some, such, both, several, they, those, which,
Average loss at step 42000: 3.433029
Average loss at step 44000: 3.447810
Average loss at step 46000: 3.450117
Average loss at step 48000: 3.349537
Average loss at step 50000: 3.382025
Nearest to may: can, would, will, could, should, must, might, shall,
Nearest to years: days, times, decades, centuries, months, year, willamette, utc,
Nearest to one: eight, six, two, four, seven, nine, five, three,
Nearest to with: between, flips, bigcup, beeb, against, landmark, in, among,
Nearest to were: are, was, have, had, be, been, those, being,
Nearest to time: year, period, way, day, criticising, settled, lowercase, mucus,
Nearest to his: their, her, its, my, the, your, s, theo,
Nearest to for: while, patterned, by, hays, when, including, towards, lenoir,
Nearest to was: is, has, became, had, were, be, by, recalls,
Nearest to i: ii, you, we, t, cheerfully, adjoined, forever, she,
Nearest to many: some, several, these, both, various, those, transliteration, all,
Nearest to while: when, however, but, although, though, and, where, after,
Nearest to this: which, it, he, another, anisotropies, dine, the, some,
Nearest to on: upon, in, at, friar, over, around, about, waveforms,
Nearest to only: renderer, until, ep, chef, still, dimensionality, slippage, limes,
Nearest to these: many, some, both, such, several, those, they, all,
Average loss at step 52000: 3.432552
Average loss at step 54000: 3.425564
Average loss at step 56000: 3.439667
Average loss at step 58000: 3.392906
Average loss at step 60000: 3.396072
Nearest to may: can, would, will, could, should, must, might, shall,
Nearest to years: days, decades, times, centuries, months, year, mobilizing, mi,
Nearest to one: four, eight, seven, six, two, three, nine, five,
Nearest to with: between, including, bigcup, among, when, by, in, toro,
Nearest to were: are, had, was, have, be, those, been, while,
Nearest to time: year, way, period, viewpoint, day, mucus, criticising, sensed,
Nearest to his: her, their, its, my, your, gn, the, pms,
Nearest to for: einer, fanatic, practical, of, inflows, when, while, kelsey,
Nearest to was: is, had, became, were, has, although, when, did,
Nearest to i: we, you, ii, t, adjoined, nebula, cheerfully, they,
Nearest to many: some, several, these, various, both, most, transliteration, those,
Nearest to while: although, when, though, after, before, including, capacitive, where,
Nearest to this: it, which, the, there, that, some, another, itself,
Nearest to on: upon, through, in, rota, asymmetrical, hutcheson, against, toyota,
Nearest to only: until, renderer, still, actually, chef, always, also, dimensionality,
Nearest to these: many, some, several, those, such, both, all, they,
Average loss at step 62000: 3.238528
Average loss at step 64000: 3.258838
Average loss at step 66000: 3.398573
Average loss at step 68000: 3.393625
Average loss at step 70000: 3.358500
Nearest to may: can, would, will, could, should, must, might, shall,
Nearest to years: days, decades, months, times, centuries, year, weeks, mi,
Nearest to one: six, two, seven, four, eight, nine, five, three,
Nearest to with: between, among, prejudiced, flips, denslow, when, in, soho,
Nearest to were: are, was, had, have, these, although, while, be,
Nearest to time: year, period, way, wendy, day, viewpoint, legality, mucus,
Nearest to his: her, their, its, my, your, gn, our, pms,
Nearest to for: in, while, cosmologists, einer, hauling, to, of, goidelic,
Nearest to was: is, had, has, were, became, been, does, be,
Nearest to i: ii, you, we, t, adjoined, yard, sq, g,
Nearest to many: some, several, these, various, all, both, most, transliteration,
Nearest to while: although, when, though, where, paladins, before, if, and,
Nearest to this: it, which, the, that, another, there, any, some,
Nearest to on: upon, in, through, rota, within, asymmetrical, at, against,
Nearest to only: still, actually, until, appleseed, always, renderer, dimensionality, skew,
Nearest to these: many, such, some, those, several, are, various, they,
Average loss at step 72000: 3.370355
Average loss at step 74000: 3.349001
Average loss at step 76000: 3.310726
Average loss at step 78000: 3.351032
Average loss at step 80000: 3.377399
Nearest to may: can, would, will, could, must, should, might, shall,
Nearest to years: days, decades, months, times, centuries, year, weeks, ways,
Nearest to one: seven, six, two, eight, four, three, five, nine,
Nearest to with: between, among, when, flips, denslow, giraffes, including, vit,
Nearest to were: are, had, was, have, including, been, be, being,
Nearest to time: year, period, day, success, place, freezing, century, stuck,
Nearest to his: her, their, its, my, your, our, pms, gn,
Nearest to for: hays, when, if, while, einer, against, cosmologists, practical,
Nearest to was: is, became, were, has, had, been, when, be,
Nearest to i: ii, you, we, t, iv, iii, adjoined, g,
Nearest to many: some, several, various, these, most, both, all, those,
Nearest to while: although, though, when, after, before, paladins, capacitive, but,
Nearest to this: which, it, another, itself, there, the, some, richness,
Nearest to on: upon, in, through, at, abrahamic, pliny, during, against,
Nearest to only: actually, always, poudre, late, still, until, probably, last,
Nearest to these: many, those, several, various, both, some, such, all,
Average loss at step 82000: 3.406727
Average loss at step 84000: 3.410662
Average loss at step 86000: 3.392188
Average loss at step 88000: 3.352461
Average loss at step 90000: 3.362711
Nearest to may: can, could, would, will, must, should, might, cannot,
Nearest to years: days, decades, months, times, weeks, centuries, year, minutes,
Nearest to one: seven, four, eight, two, six, five, three, nine,
Nearest to with: between, among, including, bigcup, vit, by, prejudiced, weaver,
Nearest to were: are, was, had, have, while, although, been, being,
Nearest to time: year, period, day, week, lemonade, if, synthetically, eriksson,
Nearest to his: her, their, its, my, your, our, the, wakefulness,
Nearest to for: including, after, of, while, hays, when, einer, during,
Nearest to was: is, had, were, became, has, be, been, chukotka,
Nearest to i: ii, t, we, you, iv, g, iii, forbes,
Nearest to many: some, several, these, various, all, most, those, both,
Nearest to while: although, when, before, though, after, but, paladins, were,
Nearest to this: it, which, some, the, richness, another, itself, he,
Nearest to on: upon, under, at, against, through, in, budge, abrahamic,
Nearest to only: actually, poudre, still, even, dimensionality, no, skew, always,
Nearest to these: many, some, several, are, various, those, such, all,
Average loss at step 92000: 3.394156
Average loss at step 94000: 3.254331
Average loss at step 96000: 3.350147
Average loss at step 98000: 3.238525
Average loss at step 100000: 3.357658
Nearest to may: can, could, must, should, would, will, might, cannot,
Nearest to years: days, decades, times, months, centuries, weeks, year, minutes,
Nearest to one: two, four, seven, eight, three, nine, six, five,
Nearest to with: between, including, among, vit, when, throughout, in, toro,
Nearest to were: are, was, have, had, including, although, these, those,
Nearest to time: year, week, period, way, reason, day, lemonade, synthetically,
Nearest to his: her, their, your, its, my, the, our, s,
Nearest to for: when, patterned, after, before, including, while, unsurprisingly, if,
Nearest to was: is, became, had, has, were, although, tremendously, been,
Nearest to i: ii, we, you, t, iv, iii, they, adjoined,
Nearest to many: several, some, various, these, few, numerous, those, all,
Nearest to while: when, although, though, before, paladins, are, where, if,
Nearest to this: which, it, another, itself, some, he, the, richness,
Nearest to on: upon, in, through, under, at, against, during, budge,
Nearest to only: still, actually, even, supposes, not, along, never, until,
Nearest to these: several, some, many, various, those, all, are, both,
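
With final_embeddings computed, nearest neighbors for any vocabulary word can also be looked up directly in NumPy, outside the graph. A minimal sketch; the helper name nearest_neighbors and the query word 'france' are only illustrative, and the word is assumed to be in the 50,000-word vocabulary:


In [ ]:
# Rows of final_embeddings are already L2-normalized, so a dot product against
# a query row gives cosine similarities to every word in the vocabulary.
def nearest_neighbors(word, k=8):
    query = final_embeddings[dictionary[word]]
    sims = final_embeddings.dot(query)
    # sort by descending similarity and skip the query word itself
    best = (-sims).argsort()[1:k + 1]
    return [reverse_dictionary[idx] for idx in best]

print(nearest_neighbors('france'))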

In [8]:
num_points = 400

tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=8000)
two_d_embeddings = tsne.fit_transform(final_embeddings[1:num_points+1,:])

In [9]:
def plot(embeddings, labels):
    assert embeddings.shape[0] >= len(labels), 'More labels than embeddings'
    pylab.figure(figsize=(15,15))  # in inches
    for i, label in enumerate(labels):
        x, y = embeddings[i,:]
        pylab.scatter(x, y)
        pylab.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points',
                       ha='right', va='bottom')
    pylab.show()

words = [reverse_dictionary[i] for i in range(1, num_points+1)]
plot(two_d_embeddings, words)


/usr/lib/pymodules/python2.7/matplotlib/collections.py:548: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  if self._edgecolors == 'face':

Problem

An alternative to skip-gram is another Word2Vec model called CBOW (Continuous Bag of Words). In the CBOW model, instead of predicting a context word from a word vector, you predict a word from the sum of all the word vectors in its context. Implement and evaluate a CBOW model trained on the text8 dataset.
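
One way to start, shown below as a sketch rather than a finished solution: a CBOW batch holds the 2 * skip_window context-word ids for each position, the label is the center word, and in the graph the looked-up context embeddings are summed or averaged into a single vector before the sampled softmax loss. The names generate_cbow_batch and cbow_index are illustrative; the sketch reuses data, reverse_dictionary and the variables defined above.


In [ ]:
# Sketch of a CBOW batch generator: each row of batch is a full context window
# (2 * skip_window word ids) and the corresponding label is the center word.
cbow_index = 0

def generate_cbow_batch(batch_size, skip_window):
    global cbow_index
    span = 2 * skip_window + 1  # [ skip_window | center | skip_window ]
    batch = np.ndarray(shape=(batch_size, span - 1), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    buffer = collections.deque(maxlen=span)
    for _ in range(span):
        buffer.append(data[cbow_index])
        cbow_index = (cbow_index + 1) % len(data)
    for i in range(batch_size):
        # every word in the window except the center is context
        batch[i, :] = [buffer[j] for j in range(span) if j != skip_window]
        labels[i, 0] = buffer[skip_window]
        # slide the window forward by one word
        buffer.append(data[cbow_index])
        cbow_index = (cbow_index + 1) % len(data)
    return batch, labels

cbow_index = 0
b, l = generate_cbow_batch(batch_size=8, skip_window=1)
print('batch:', [[reverse_dictionary[w] for w in row] for row in b])
print('labels:', [reverse_dictionary[w] for w in l.reshape(8)])

# In the graph, the input placeholder becomes [batch_size, 2 * skip_window] and
# the main structural change is combining the context embeddings per example:
#     embed = tf.reduce_mean(tf.nn.embedding_lookup(embeddings, train_dataset), 1)
# The combined vector then feeds the same sampled softmax loss as the skip-gram model.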



In [ ]:
data_index = 0

def generate_batch(batch_size, num_skips, skip_window):
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
  
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    
    # [ skip_window | center | skip_window ]
    span = 2 * skip_window + 1
    # sliding buffer holding the current window of words
    buffer = collections.deque(maxlen=span)

    # read enough words to fill the span
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    # generate the (center word, context word) skip-gram pairs
    for i in range(batch_size // num_skips):
        # the center word sits at index skip_window of the buffer
        target = skip_window
        # context indices already used (start with the center itself)
        targets_to_avoid = [skip_window]

        # sample num_skips distinct context words for this center word
        for j in range(num_skips):
            while target in targets_to_avoid:
                target = random.randint(0, span - 1)
            targets_to_avoid.append(target)
            labels[i * num_skips + j, 0] = buffer[target]
            batch[i * num_skips + j] = buffer[skip_window]

        # slide the window forward by one word
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    return batch, labels

print('data:', [reverse_dictionary[di] for di in data[:8]])
for num_skips, skip_window in [(2, 1), (4, 2)]:
    data_index = 0
    batch, labels = generate_batch(batch_size=8, num_skips=num_skips, skip_window=skip_window)
    print('\nwith num_skips = %d and skip_window = %d:' % (num_skips, skip_window))
    print('    batch:', [reverse_dictionary[bi] for bi in batch])
    print('    labels:', [reverse_dictionary[li] for li in labels.reshape(8)])

In [ ]:
# How many times to reuse an input to generate a label.
num_skips = 2 
# How many words to consider left and right.
skip_window = 1 
# batch size for training
batch_size = 128
# Dimension of the embedding vector.
embedding_size = 128

# We pick a random validation set to sample nearest neighbors. Here we limit
# the validation samples to words that have a low numeric ID, which by
# construction are also the most frequent.

# Random set of words to evaluate similarity on
valid_size = 16 
# Only pick dev samples in the head of the distribution.
valid_window = 100 
valid_examples = np.array(random.sample(range(valid_window), valid_size))

# Number of negative examples to sample.
num_sampled = 64 

# generate computation graph
graph = tf.Graph()
with graph.as_default(), tf.device('/cpu:0'):

    # Input data.
    train_dataset = tf.placeholder(tf.int32, shape=[batch_size])
    train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
    valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
  
    # Variables.
    embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    softmax_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size],
                                                      stddev=1.0 / math.sqrt(embedding_size)))
    softmax_biases = tf.Variable(tf.zeros([vocabulary_size]))
  
    # Model.
    # Look up embeddings for inputs.
    embed = tf.nn.embedding_lookup(embeddings, train_dataset)
    # Compute the softmax loss, using a sample of the negative labels each time.
    # Keyword arguments are used because the positional order of these arguments
    # differs between TensorFlow versions.
    loss = tf.reduce_mean(
        tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases,
                                   inputs=embed, labels=train_labels,
                                   num_sampled=num_sampled, num_classes=vocabulary_size))


    # Optimizer.
    optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)
  
    # Compute the similarity between minibatch examples and all embeddings.
    # We use cosine similarity:
    norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
    normalized_embeddings = embeddings / norm
    
    valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset)
    similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings))

In [ ]:
num_steps = 100001

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print('Initialized')
    
    # main loop
    average_loss = 0
    for step in range(num_steps):
        # generate a training batch for this step
        batch_data, batch_labels = generate_batch(batch_size, num_skips, skip_window)
        # feed dictionary for optimizer
        feed_dict = {train_dataset : batch_data, train_labels : batch_labels}
        
        # run optimizer
        _, l = session.run([optimizer, loss], feed_dict=feed_dict)
        average_loss += l
        
        # report the average loss every 2000 steps
        if step % 2000 == 0:
            if step > 0:
                average_loss = average_loss / 2000
            # The average loss is an estimate of the loss over the last 2000 batches.
            print('Average loss at step %d: %f' % (step, average_loss))
            average_loss = 0
    
        # note that this is expensive (~20% slowdown if computed every 500 steps)
        if step % 10000 == 0:
            sim = similarity.eval()            
            for i in range(valid_size):
                valid_word = reverse_dictionary[valid_examples[i]]
                top_k = 8 # number of nearest neighbors
                nearest = (-sim[i,:]).argsort()[1:top_k+1]
                log = 'Nearest to %s:' % valid_word
                for k in range(top_k):
                    close_word = reverse_dictionary[nearest[k]]
                    log = '%s %s,' % (log, close_word)
                print(log)
    
    final_embeddings = normalized_embeddings.eval()