BLSLDA Profiling


In [1]:
from modules.helpers import plot_images
from functools import partial
from sklearn.metrics import (roc_auc_score, roc_curve)
import seaborn as sns

imshow = partial(plt.imshow, cmap='gray', interpolation='nearest', aspect='auto')
sns.set(style='white')

Generate document-term matrix


In [2]:
V = 100
K = 10
N = 100
D = 1000
alpha = np.repeat(1., K)
beta = np.repeat(1., V)

In [3]:
# generate phi
phi = np.random.RandomState(42).dirichlet(beta, size=K)
imshow(phi)


Out[3]:
<matplotlib.image.AxesImage at 0x114ff8940>

In [4]:
# generate theta
theta = np.random.RandomState(42).dirichlet(alpha, size=D)
imshow(theta)


Out[4]:
<matplotlib.image.AxesImage at 0x1151386d8>

In [5]:
# generate document-term matrix
doc_term_matrix = np.zeros((D, V), dtype=np.int64)
rng = np.random.RandomState(42)
for d in range(D):
    topic_histogram = rng.multinomial(N, theta[d])
    for k in range(K):
        doc_term_matrix[d] += rng.multinomial(topic_histogram[k], phi[k])
imshow(doc_term_matrix)


Out[5]:
<matplotlib.image.AxesImage at 0x1151a0a58>

Generate responses


In [6]:
# choose parameter values
mu = 0.
nu2 = 1.
eta = np.random.RandomState(4).normal(loc=mu, scale=nu2, size=K)
pd.Series(eta).plot(kind='bar')


Out[6]:
<matplotlib.axes._subplots.AxesSubplot at 0x115176518>

In [7]:
# plot histogram of pre-responses
zeta = np.array([np.dot(eta, theta[d]) for d in range(D)])
pd.Series(zeta).hist(bins=50)


Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x115d8e780>

In [8]:
# set y
y = (zeta >= 0).astype(int)
# plot histogram of responses
pd.Series(y).hist()
plt.title('positive examples {} ({:.2f}%)'.format(y.sum(), y.sum() / D * 100))


Out[8]:
<matplotlib.text.Text at 0x115ee9eb8>

Estimate parameters


In [9]:
import pstats, cProfile
from slda.topic_models import BLSLDA

In [10]:
_K = 10
_alpha = alpha
_beta = np.repeat(0.01, V)
_mu = mu
_nu2 = nu2
_b = 8
n_iter = 200
blslda = BLSLDA(_K, _alpha, _beta, _mu, _nu2, _b, n_iter, seed=42)

In [11]:
%%time
cProfile.runctx("blslda.fit(doc_term_matrix, y)", globals(), locals(), "Profile.prof")


2015-08-25 23:01:53.929901 start iterations
2015-08-25 23:01:54.410328 0:00:00.480427 elapsed, iter   10, LL -250839.0493, 13.35% change from last
2015-08-25 23:01:54.839207 0:00:00.909306 elapsed, iter   20, LL -238917.2485, 4.75% change from last
2015-08-25 23:01:55.257168 0:00:01.327267 elapsed, iter   30, LL -231327.6734, 3.18% change from last
2015-08-25 23:01:55.682587 0:00:01.752686 elapsed, iter   40, LL -225735.7314, 2.42% change from last
2015-08-25 23:01:56.117728 0:00:02.187827 elapsed, iter   50, LL -220733.3096, 2.22% change from last
2015-08-25 23:01:56.540373 0:00:02.610472 elapsed, iter   60, LL -216555.1285, 1.89% change from last
2015-08-25 23:01:56.961528 0:00:03.031627 elapsed, iter   70, LL -212642.5964, 1.81% change from last
2015-08-25 23:01:57.383279 0:00:03.453378 elapsed, iter   80, LL -211006.7754, 0.77% change from last
2015-08-25 23:01:57.815986 0:00:03.886085 elapsed, iter   90, LL -208124.3425, 1.37% change from last
2015-08-25 23:01:58.244966 0:00:04.315065 elapsed, iter  100, LL -205301.7756, 1.36% change from last
2015-08-25 23:01:58.667271 0:00:04.737370 elapsed, iter  110, LL -203570.5584, 0.84% change from last
2015-08-25 23:01:59.092978 0:00:05.163077 elapsed, iter  120, LL -200980.3643, 1.27% change from last
2015-08-25 23:01:59.511600 0:00:05.581699 elapsed, iter  130, LL -200127.0719, 0.42% change from last
2015-08-25 23:01:59.932932 0:00:06.003031 elapsed, iter  140, LL -198705.7298, 0.71% change from last
2015-08-25 23:02:00.360819 0:00:06.430918 elapsed, iter  150, LL -197738.3814, 0.49% change from last
2015-08-25 23:02:00.787860 0:00:06.857959 elapsed, iter  160, LL -197591.5563, 0.07% change from last
2015-08-25 23:02:01.228314 0:00:07.298413 elapsed, iter  170, LL -196815.0336, 0.39% change from last
2015-08-25 23:02:01.667865 0:00:07.737964 elapsed, iter  180, LL -195774.9359, 0.53% change from last
2015-08-25 23:02:02.093168 0:00:08.163267 elapsed, iter  190, LL -194476.6685, 0.66% change from last
CPU times: user 8.51 s, sys: 40.1 ms, total: 8.55 s
Wall time: 8.56 s

In [12]:
s = pstats.Stats("Profile.prof")
s.strip_dirs().sort_stats("tottime").print_stats()


Tue Aug 25 23:02:02 2015    Profile.prof

         20018581 function calls in 8.556 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    7.086    7.086    8.550    8.550 _topic_models.pyx:497(gibbs_sampler_blslda)
 20000000    1.295    0.000    1.295    0.000 _topic_models.pyx:64(searchsorted)
      200    0.131    0.001    0.131    0.001 _topic_models.pyx:151(loglikelihood_blslda)
        1    0.008    0.008    0.008    0.008 _topic_models.pyx:33(create_rands)
      200    0.005    0.000    0.012    0.000 linalg.py:296(solve)
     1015    0.004    0.000    0.004    0.000 {built-in method array}
        1    0.002    0.002    0.002    0.002 {method 'nonzero' of 'numpy.ndarray' objects}
        8    0.002    0.000    0.002    0.000 {method 'repeat' of 'numpy.ndarray' objects}
        1    0.002    0.002    0.002    0.002 _topic_models.pyx:48(create_topic_lookup)
      811    0.002    0.000    0.002    0.000 stringsource:956(memoryview_fromslice)
     1007    0.002    0.000    0.005    0.000 numeric.py:394(asarray)
     1229    0.001    0.000    0.001    0.000 stringsource:317(__cinit__)
      200    0.001    0.000    0.002    0.000 linalg.py:139(_commonType)
      200    0.001    0.000    0.008    0.000 _topic_models.pyx:214(print_progress)
      400    0.001    0.000    0.002    0.000 linalg.py:106(_makearray)
      200    0.001    0.000    0.001    0.000 {method 'astype' of 'numpy.ndarray' objects}
      200    0.001    0.000    0.001    0.000 linalg.py:209(_assertNdSquareness)
       76    0.001    0.000    0.001    0.000 encoder.py:197(iterencode)
      200    0.001    0.000    0.001    0.000 linalg.py:198(_assertRankAtLeast2)
      200    0.000    0.000    0.000    0.000 linalg.py:101(get_linalg_error_extobj)
      133    0.000    0.000    0.000    0.000 {method 'send' of 'zmq.backend.cython.socket.Socket' objects}
     1225    0.000    0.000    0.000    0.000 stringsource:339(__dealloc__)
     1000    0.000    0.000    0.000    0.000 {built-in method issubclass}
      494    0.000    0.000    0.000    0.000 {method 'sub' of '_sre.SRE_Pattern' objects}
        1    0.000    0.000    0.005    0.005 topic_models.py:29(_create_lookups)
      494    0.000    0.000    0.001    0.000 encoder.py:33(encode_basestring)
       19    0.000    0.000    0.000    0.000 {zmq.backend.cython._poll.zmq_poll}
      400    0.000    0.000    0.000    0.000 linalg.py:124(_realType)
      600    0.000    0.000    0.000    0.000 linalg.py:111(isComplexType)
       19    0.000    0.000    0.005    0.000 session.py:589(send)
       40    0.000    0.000    0.007    0.000 iostream.py:207(write)
      418    0.000    0.000    0.001    0.000 stringsource:613(memoryview_cwrapper)
       76    0.000    0.000    0.002    0.000 __init__.py:182(dumps)
       19    0.000    0.000    0.007    0.000 iostream.py:151(flush)
      421    0.000    0.000    0.000    0.000 {built-in method getattr}
       19    0.000    0.000    0.001    0.000 uuid.py:596(uuid4)
      247    0.000    0.000    0.000    0.000 traitlets.py:395(__get__)
       19    0.000    0.000    0.003    0.000 session.py:530(serialize)
      228    0.000    0.000    0.000    0.000 {built-in method max}
       19    0.000    0.000    0.000    0.000 {built-in method urandom}
       19    0.000    0.000    0.000    0.000 uuid.py:104(__init__)
       76    0.000    0.000    0.002    0.000 encoder.py:175(encode)
       76    0.000    0.000    0.002    0.000 jsonapi.py:31(dumps)
        3    0.000    0.000    0.000    0.000 {method 'reduce' of 'numpy.ufunc' objects}
       19    0.000    0.000    0.001    0.000 session.py:515(sign)
       19    0.000    0.000    0.001    0.000 socket.py:250(send_multipart)
      419    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
       19    0.000    0.000    0.001    0.000 iostream.py:123(_flush_from_subprocesses)
       19    0.000    0.000    0.001    0.000 session.py:496(msg)
       19    0.000    0.000    0.000    0.000 attrsettr.py:35(__getattr__)
      421    0.000    0.000    0.000    0.000 {built-in method isinstance}
       76    0.000    0.000    0.002    0.000 session.py:84(<lambda>)
        2    0.000    0.000    0.000    0.000 _topic_models.pyx:230(estimate_matrix)
      807    0.000    0.000    0.000    0.000 stringsource:508(__get__)
       19    0.000    0.000    0.000    0.000 uuid.py:230(__str__)
       76    0.000    0.000    0.000    0.000 encoder.py:98(__init__)
      355    0.000    0.000    0.000    0.000 {built-in method len}
       19    0.000    0.000    0.001    0.000 session.py:493(msg_header)
     1007    0.000    0.000    0.000    0.000 stringsource:468(__getbuffer__)
       19    0.000    0.000    0.001    0.000 session.py:441(msg_id)
      809    0.000    0.000    0.000    0.000 stringsource:932(__dealloc__)
       95    0.000    0.000    0.000    0.000 {method 'update' of '_hashlib.HASH' objects}
       19    0.000    0.000    0.000    0.000 hmac.py:95(copy)
       21    0.000    0.000    0.000    0.000 {built-in method hasattr}
       19    0.000    0.000    0.000    0.000 session.py:200(extract_header)
       78    0.000    0.000    0.000    0.000 iostream.py:93(_is_master_process)
       59    0.000    0.000    0.000    0.000 iostream.py:102(_check_mp_mode)
      422    0.000    0.000    0.000    0.000 stringsource:619(memoryview_check)
       19    0.000    0.000    0.000    0.000 iostream.py:96(_is_master_thread)
       19    0.000    0.000    0.000    0.000 iostream.py:238(_flush_buffer)
       19    0.000    0.000    0.000    0.000 {built-in method now}
       19    0.000    0.000    0.000    0.000 poll.py:77(poll)
       19    0.000    0.000    0.000    0.000 iostream.py:247(_new_buffer)
      200    0.000    0.000    0.000    0.000 {built-in method min}
        1    0.000    0.000    8.556    8.556 <string>:1(<module>)
       95    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
        3    0.000    0.000    0.000    0.000 shape_base.py:792(tile)
       19    0.000    0.000    0.000    0.000 {method 'isoformat' of 'datetime.datetime' objects}
       19    0.000    0.000    0.000    0.000 session.py:655(<listcomp>)
       57    0.000    0.000    0.000    0.000 {method 'copy' of '_hashlib.HASH' objects}
       19    0.000    0.000    0.000    0.000 threading.py:1230(current_thread)
       19    0.000    0.000    0.000    0.000 session.py:195(msg_header)
       19    0.000    0.000    0.000    0.000 hmac.py:108(_current)
       19    0.000    0.000    0.000    0.000 {built-in method from_bytes}
       19    0.000    0.000    0.000    0.000 encoder.py:37(replace)
       19    0.000    0.000    0.000    0.000 jsonutil.py:102(date_default)
       19    0.000    0.000    0.000    0.000 {method 'getvalue' of '_io.StringIO' objects}
        1    0.000    0.000    8.550    8.550 {modules._topic_models.gibbs_sampler_blslda}
       76    0.000    0.000    0.000    0.000 hmac.py:90(update)
       19    0.000    0.000    0.000    0.000 hmac.py:127(hexdigest)
       19    0.000    0.000    0.000    0.000 {method 'digest' of '_hashlib.HASH' objects}
       19    0.000    0.000    0.000    0.000 {built-in method locals}
       97    0.000    0.000    0.000    0.000 {built-in method getpid}
        1    0.000    0.000    8.556    8.556 topic_models.py:271(fit)
      200    0.000    0.000    0.000    0.000 {method '__array_prepare__' of 'numpy.ndarray' objects}
       40    0.000    0.000    0.000    0.000 {built-in method time}
        1    0.000    0.000    8.556    8.556 {built-in method exec}
        4    0.000    0.000    0.002    0.001 fromnumeric.py:350(repeat)
       40    0.000    0.000    0.000    0.000 {method 'write' of '_io.StringIO' objects}
       19    0.000    0.000    0.000    0.000 {method 'hexdigest' of '_hashlib.HASH' objects}
       19    0.000    0.000    0.000    0.000 py3compat.py:18(encode)
       19    0.000    0.000    0.000    0.000 threading.py:1096(ident)
       38    0.000    0.000    0.000    0.000 {method 'extend' of 'list' objects}
        1    0.000    0.000    0.000    0.000 twodim_base.py:190(eye)
        5    0.000    0.000    0.000    0.000 numeric.py:516(ascontiguousarray)
        2    0.000    0.000    0.000    0.000 fromnumeric.py:40(_wrapit)
       19    0.000    0.000    0.000    0.000 {method 'close' of '_io.StringIO' objects}
        7    0.000    0.000    0.000    0.000 {method 'reshape' of 'numpy.ndarray' objects}
       19    0.000    0.000    0.000    0.000 {built-in method get_ident}
       57    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 numeric.py:2125(identity)
       76    0.000    0.000    0.000    0.000 {method 'join' of 'str' objects}
       19    0.000    0.000    0.000    0.000 {built-in method __new__ of type object at 0x10e6340f0}
       19    0.000    0.000    0.000    0.000 {method 'count' of 'list' objects}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:2264(_handle_fromlist)
       19    0.000    0.000    0.000    0.000 {method 'copy' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 base.py:865(isspmatrix)
       19    0.000    0.000    0.000    0.000 {method 'upper' of 'str' objects}
       19    0.000    0.000    0.000    0.000 {method 'group' of '_sre.SRE_Match' objects}
        1    0.000    0.000    0.002    0.002 fromnumeric.py:1380(nonzero)
        3    0.000    0.000    0.000    0.000 _methods.py:31(_sum)
        1    0.000    0.000    0.000    0.000 {built-in method zeros}
        1    0.000    0.000    0.000    0.000 {method 'sum' of 'numpy.ndarray' objects}
        4    0.000    0.000    0.000    0.000 stringsource:949(__get__)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


Out[12]:
<pstats.Stats at 0x115fdadd8>

other stuff


In [13]:
%%time
blslda.fit(doc_term_matrix, y)


2015-08-25 23:02:02.499818 start iterations
2015-08-25 23:02:02.816054 0:00:00.316236 elapsed, iter   10, LL -250839.0493, 13.35% change from last
2015-08-25 23:02:03.084356 0:00:00.584538 elapsed, iter   20, LL -238917.2485, 4.75% change from last
2015-08-25 23:02:03.369666 0:00:00.869848 elapsed, iter   30, LL -231327.6734, 3.18% change from last
2015-08-25 23:02:03.656577 0:00:01.156759 elapsed, iter   40, LL -225735.7314, 2.42% change from last
2015-08-25 23:02:03.937644 0:00:01.437826 elapsed, iter   50, LL -220733.3096, 2.22% change from last
2015-08-25 23:02:04.223371 0:00:01.723553 elapsed, iter   60, LL -216555.1285, 1.89% change from last
2015-08-25 23:02:04.503353 0:00:02.003535 elapsed, iter   70, LL -212642.5964, 1.81% change from last
2015-08-25 23:02:04.780910 0:00:02.281092 elapsed, iter   80, LL -211006.7754, 0.77% change from last
2015-08-25 23:02:05.061101 0:00:02.561283 elapsed, iter   90, LL -208124.3425, 1.37% change from last
2015-08-25 23:02:05.347142 0:00:02.847324 elapsed, iter  100, LL -205301.7756, 1.36% change from last
2015-08-25 23:02:05.630575 0:00:03.130757 elapsed, iter  110, LL -203570.5584, 0.84% change from last
2015-08-25 23:02:05.902224 0:00:03.402406 elapsed, iter  120, LL -200980.3643, 1.27% change from last
2015-08-25 23:02:06.187869 0:00:03.688051 elapsed, iter  130, LL -200127.0719, 0.42% change from last
2015-08-25 23:02:06.468666 0:00:03.968848 elapsed, iter  140, LL -198705.7298, 0.71% change from last
2015-08-25 23:02:06.750928 0:00:04.251110 elapsed, iter  150, LL -197738.3814, 0.49% change from last
2015-08-25 23:02:07.033691 0:00:04.533873 elapsed, iter  160, LL -197591.5563, 0.07% change from last
2015-08-25 23:02:07.306342 0:00:04.806524 elapsed, iter  170, LL -196815.0336, 0.39% change from last
2015-08-25 23:02:07.580268 0:00:05.080450 elapsed, iter  180, LL -195774.9359, 0.53% change from last
2015-08-25 23:02:07.859481 0:00:05.359663 elapsed, iter  190, LL -194476.6685, 0.66% change from last
CPU times: user 5.61 s, sys: 14.8 ms, total: 5.62 s
Wall time: 5.63 s

In [14]:
plot_images(plt, blslda.phi, (10, 10), (2, 5), figsize=(10, 5))



In [15]:
imshow(blslda.theta)


Out[15]:
<matplotlib.image.AxesImage at 0x117578860>

In [16]:
plt.plot(blslda.loglikelihoods)


Out[16]:
[<matplotlib.lines.Line2D at 0x1175d47b8>]

In [17]:
burn_in = 100
eta_pred = blslda.eta[burn_in:].mean(axis=0)
print('mean log-likelihood {}'.format(blslda.loglikelihoods[burn_in:].mean()))
eta_pred


mean log-likelihood -198916.3680838434
Out[17]:
array([ -7.65792495,  -8.81479835,  -7.21599757,   5.49959948,
        -4.39817516,  -5.01795172,  -7.74796161,  -7.27304832,
        -8.34087661,  14.49302745])