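Word Mover's Distance

Compute pairwise Word Mover's Distance (WMD) between documents of the tokenized EOS corpus using a Word2Vec embedding trained on that corpus, then plot the resulting distance matrix.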


In [1]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from gensim.models import Word2Vec
from gensim.similarities import WmdSimilarity
from sklearn.metrics.pairwise import pairwise_distances

In [5]:
# Load the tokenized EOS corpus (a pickled sequence of token lists, as consumed
# by Word2Vec below -- presumably list[list[str]]; confirm against the producer).
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
doc_filepath = 'data/eos/eos_tokenize_all.p'
with open(doc_filepath, "rb") as corpus_file:  # closes the handle even if unpickling fails
    eos_corpus = pickle.load(corpus_file)

In [2]:
%%time

word2vec_model_file = 'data/eos/word2vec_model'

if 0 == 1:
    # Train Word2Vec on all the restaurants.
    word2vec_model = Word2Vec(eos_corpus, workers=7, size=100)
    word2vec_model.init_sims(replace=True)  # Normalizes the vectors in the word2vec class.
    word2vec_model.save(word2vec_model_file)

else:
    word2vec_model = Word2Vec.load(word2vec_model_file)  # you can continue training with the loaded model!


CPU times: user 1.13 s, sys: 84 ms, total: 1.21 s
Wall time: 1.62 s

In [3]:
# Sanity-check the embedding: nearest neighbours of a well-known token.
# Query via the KeyedVectors (`.wv`); Word2Vec.most_similar is deprecated and removed in gensim 4.
word2vec_model.wv.most_similar('obama')


Out[3]:
[('barack', 0.7637264728546143),
 ('bush', 0.6618968844413757),
 ('puterilor', 0.6543050408363342),
 ('erlassenen', 0.6493631601333618),
 ('gleichsetzen', 0.6100968718528748),
 ('castro', 0.5896024703979492),
 ('abe', 0.578781008720398),
 ('earnest', 0.5780454874038696),
 ('clinton', 0.5418081879615784),
 ('jokowi', 0.5362154245376587)]

In [6]:
%time


# Initialize WmdSimilarity.
from gensim.similarities import WmdSimilarity
num_best = 10
instance = WmdSimilarity(eos_corpus, word2vec_model, num_best=10)


CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 3.34 µs

In [8]:
## Not scalable!!!
A = np.array([[i] for i in range(len(eos_corpus))])

A = np.array([[i] for i in range(1000)])

def f(x, y):
    return word2vec_model.wmdistance(eos_corpus[int(x)], eos_corpus[int(y)])

X_wmd_distance = pairwise_distances(A, metric=f, n_jobs=-1)


---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-8-0f1430fa2555> in <module>()
      7     return word2vec_model.wmdistance(eos_corpus[int(x)], eos_corpus[int(y)])
      8 
----> 9 X_wmd_distance = pairwise_distances(A, metric=f, n_jobs=-1)

/usr/local/lib/python3.5/dist-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, **kwds)
   1238         func = partial(distance.cdist, metric=metric, **kwds)
   1239 
-> 1240     return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1241 
   1242 

/usr/local/lib/python3.5/dist-packages/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1087     ret = Parallel(n_jobs=n_jobs, verbose=0)(
   1088         fd(X, Y[s], **kwds)
-> 1089         for s in gen_even_slices(Y.shape[0], n_jobs))
   1090 
   1091     return np.hstack(ret)

/usr/local/lib/python3.5/dist-packages/sklearn/externals/joblib/parallel.py in __call__(self, iterable)
    766                 # consumption.
    767                 self._iterating = False
--> 768             self.retrieve()
    769             # Make sure that we get a last message telling us we are done
    770             elapsed_time = time.time() - self._start_time

/usr/local/lib/python3.5/dist-packages/sklearn/externals/joblib/parallel.py in retrieve(self)
    717                     ensure_ready = self._managed_backend
    718                     backend.abort_everything(ensure_ready=ensure_ready)
--> 719                 raise exception
    720 
    721     def __call__(self, iterable):

/usr/local/lib/python3.5/dist-packages/sklearn/externals/joblib/parallel.py in retrieve(self)
    680                 # check if timeout supported in backend future implementation
    681                 if 'timeout' in getfullargspec(job.get).args:
--> 682                     self._output.extend(job.get(timeout=self.timeout))
    683                 else:
    684                     self._output.extend(job.get())

/usr/lib/python3.5/multiprocessing/pool.py in get(self, timeout)
    600 
    601     def get(self, timeout=None):
--> 602         self.wait(timeout)
    603         if not self.ready():
    604             raise TimeoutError

/usr/lib/python3.5/multiprocessing/pool.py in wait(self, timeout)
    597 
    598     def wait(self, timeout=None):
--> 599         self._event.wait(timeout)
    600 
    601     def get(self, timeout=None):

/usr/lib/python3.5/threading.py in wait(self, timeout)
    547             signaled = self._flag
    548             if not signaled:
--> 549                 signaled = self._cond.wait(timeout)
    550             return signaled
    551 

/usr/lib/python3.5/threading.py in wait(self, timeout)
    291         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    292             if timeout is None:
--> 293                 waiter.acquire()
    294                 gotit = True
    295             else:

KeyboardInterrupt: 

In [ ]:
# Heat map of the pairwise WMD matrix (larger values = less similar documents).
fig, ax = plt.subplots()
wmd_image = ax.imshow(X_wmd_distance)
ax.set(title='Pairwise WMD distance of EOS data',
       xlabel='document index', ylabel='document index')
fig.colorbar(wmd_image, ax=ax)
plt.show()

In [ ]: