Collapsed Gibbs sampler for supervised latent Dirichlet allocation

$ \newcommand{\dir}{\mathop{\rm Dirichlet}\nolimits} \newcommand{\dis}{\mathop{\rm Discrete}\nolimits} \newcommand{\normal}{\mathop{\rm Normal}\nolimits} \newcommand{\ber}{\mathop{\rm Bernoulli}\nolimits} \newcommand{\btheta}{\mathbf{\theta}} \newcommand{\norm}[1]{\left\| #1 \right\|} \newcommand{\cp}[2]{p \left( #1 \middle| #2 \right)} \newcommand{\cN}[2]{\mathscr{N} \left( #1 \middle| #2 \right)} \newcommand{\Betaf}{\mathop{\rm B}\nolimits} \newcommand{\Gammaf}{\mathop{\Gamma}\nolimits} \newcommand{\etd}[1]{\mathbf{z}^{(#1)}} \newcommand{\sumetd}{\mathbf{z}} \newcommand{\one}{\mathbf{1}} $

Here is the collapsed Gibbs sampler for Blei and McAuliffe's supervised topic models. I am building on the collapsed Gibbs sampler I wrote for latent Dirichlet allocation.

The generative model is as follows:

$$\begin{align} \theta^{(d)} &\sim \dir(\alpha) &\text{(topic distribution for document $d \in \{1, \ldots, D\}$)} \\ \phi^{(k)} &\sim \dir(\beta) &\text{(term distribution for topic $k \in \{1, \ldots, K\}$)} \\ z_n^{(d)} \mid \theta^{(d)} &\sim \dis \left( \theta^{(d)} \right) &\text{(topic of $n$th token of document $d$, $n \in \{1, \ldots, N^{(d)}\}$)} \\ w_n^{(d)} \mid \phi^{(z_n^{(d)})} &\sim \dis \left( \phi^{(z_n^{(d)})} \right) &\text{(term of $n$th token of document $d$, $n \in \{1, \ldots, N^{(d)}\}$)} \\ \eta_k &\sim \normal \left( \mu, \nu^2 \right) &\text{(regression coefficient for topic $k \in \{1, \ldots, K\}$)} \\ y^{(d)} \mid \eta, \etd{d} &\sim \normal \left( \eta \cdot \etd{d}, \sigma^2 \right) &\text{(response value of document $d \in \{1, \ldots, D\}$)} \end{align}$$

where each token can be any one of $V$ terms in our vocabulary, and $\etd{d}$ is the empirical topic distribution of document $d$.

Plate notation for supervised latent Dirichlet allocation.
To match our notation, the diagram's $\beta_k$ should be read as $\phi^{(k)}$, with each $\phi^{(k)}$ depending on a single hyperparameter $\beta$.

The joint probability distribution can be factored as follows:

$$\begin{align} \cp{\theta, \phi, z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} &= \prod_{k=1}^{K} \cp{\phi^{(k)}}{\beta} \prod_{d=1}^{D} \cp{\theta^{(d)}}{\alpha} \prod_{n=1}^{N^{(d)}} \cp{z_n^{(d)}}{\theta^{(d)}} \cp{w_n^{(d)}}{\phi^{(z_n^{(d)})}} \\ & \quad \times \prod_{k'=1}^{K} \cp{\eta_{k'}}{\mu, \nu^2} \prod_{d'=1}^D \cp{y^{(d')}}{\eta, \etd{d'}, \sigma^2} \\ &= \prod_{k=1}^{K} \frac{\Betaf(b^{(k)} + \beta)}{\Betaf(\beta)} \cp{\phi^{(k)}}{b^{(k)} + \beta} \prod_{d=1}^{D} \frac{\Betaf(a^{(d)} + \alpha)}{\Betaf(\alpha)} \cp{\theta^{(d)}}{a^{(d)} + \alpha} \\ &\quad \times \prod_{k'=1}^{K} \cN{\eta_{k'}}{\mu, \nu^2} \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} \end{align}$$

where $a_k^{(d)}$ is the number of tokens in document $d$ assigned to topic $k$, $b_v^{(k)}$ is the number of tokens equal to term $v$ and assigned to topic $k$, and $\Betaf$ is the multivariate Beta function. Marginalizing out $\theta$ and $\phi$ by integrating with respect to each $\theta^{(d)}$ and $\phi^{(k)}$ over their respective sample spaces yields

$$\begin{align} \cp{z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} &= \prod_{k=1}^{K} \frac{\Betaf(b^{(k)} + \beta)}{\Betaf(\beta)} \prod_{d=1}^{D} \frac{\Betaf(a^{(d)} + \alpha)}{\Betaf(\alpha)} \prod_{k'=1}^{K} \cN{\eta_{k'}}{\mu, \nu^2} \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} \\ &= \cp{w}{z, \beta} \cp{z}{\alpha} \cp{\eta}{\mu, \nu^2} \cp{y}{\eta, z, \sigma^2}. \end{align}$$

See my LDA notebook for step-by-step details of the previous two calculations.

Our goal is to calculate the posterior distribution

$$\cp{z, \eta}{w, y, \alpha, \beta, \mu, \nu^2, \sigma^2} = \frac{\cp{z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2}} {\sum_{z'} \int \cp{z', w, \eta', y}{\alpha, \beta, \mu, \nu^2, \sigma^2} d\eta'}$$

in order to infer the topic assignments $z$ and regression coefficients $\eta$ from the given term assignments $w$ and response data $y$. Since calculating this directly is infeasible, we resort to collapsed Gibbs sampling. The sampler is "collapsed" because we marginalized out $\theta$ and $\phi$, and will estimate them from the topic assignments $z$:

$$\hat\theta_k^{(d)} = \frac{a_k^{(d)} + \alpha_k}{\sum_{k'=1}^K \left(a_{k'}^{(d)} + \alpha_{k'} \right)},\quad \hat\phi_v^{(k)} = \frac{b_v^{(k)} + \beta_v}{\sum_{v'=1}^V \left(b_{v'}^{(k)} + \beta_{v'} \right)}.$$
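As a concrete sketch of these estimates, suppose we keep the counts in NumPy arrays, say `a` of shape $(D, K)$ for the document-topic counts and `b` of shape $(K, V)$ for the topic-term counts (illustrative names, not part of the slda package):

import numpy as np

def estimate_theta_phi(a, b, alpha, beta):
    # a: (D, K) document-topic counts; b: (K, V) topic-term counts
    # alpha: (K,) and beta: (V,) Dirichlet hyperparameters
    theta_hat = (a + alpha) / (a + alpha).sum(axis=1, keepdims=True)
    phi_hat = (b + beta) / (b + beta).sum(axis=1, keepdims=True)
    return theta_hat, phi_hat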

Gibbs sampling requires us to compute the full conditionals for each $z_n^{(d)}$ and $\eta_k$, i.e. we need to calculate, for all $n$, $d$ and $k$,

$$\begin{align} \cp{z_n^{(d)} = k}{z \setminus z_n^{(d)}, w, \eta, y, \alpha, \beta, \mu, \nu^2, \sigma^2} &\propto \cp{z_n^{(d)} = k, z \setminus z_n^{(d)}, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} \\ &\propto \frac{b_{w_n^{(d)}}^{(k)} \setminus z_n^{(d)} + \beta_{w_n^{(d)}}}{ \sum_{v=1}^V \left( b_v^{(k)} \setminus z_n^{(d)} + \beta_v\right)} \left( a_k^{(d)} \setminus z_n^{(d)} + \alpha_k \right) \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} \\ &\propto \frac{b_{w_n^{(d)}}^{(k)} \setminus z_n^{(d)} + \beta_{w_n^{(d)}}}{ \sum_{v=1}^V \left( b_v^{(k)} \setminus z_n^{(d)} + \beta_v\right)} \left( a_k^{(d)} \setminus z_n^{(d)} + \alpha_k \right) \exp \left( \frac{1}{2 \sigma^2} \frac{\eta_k}{N^{(d)}} \left( 2 \left[ y^{(d)} - \eta \cdot \left( \etd{d} \setminus z_n^{(d)} \right) \right] - \frac{\eta_k}{N^{(d)}} \right) \right) \end{align}$$

where the "set-minus" notation $\cdot \setminus z_n^{(d)}$ denotes the variable the notation is applied to with the entry $z_n^{(d)}$ removed (again, see my LDA notebook for details). This final proportionality is true since

$$\begin{align} \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} &\propto \prod_{d'=1}^{D} \exp \left( -\frac{ \left( y^{(d')} - \eta \cdot \etd{d'} \right)^2 }{2 \sigma^2} \right) \\ &\propto \prod_{d'=1}^{D} \exp \left( \frac{ 2 y^{(d')} \eta \cdot \etd{d'} - \left( \eta \cdot \etd{d'} \right)^2 }{2 \sigma^2} \right) \\ &= \prod_{d'=1}^{D} \exp \left( \frac{ 2 y^{(d')} \left( \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) + \delta_{d, d'} \frac{\eta_k}{N^{(d)}} \right) - \left( \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) + \delta_{d, d'} \frac{\eta_k}{N^{(d)}} \right)^2 }{2 \sigma^2} \right) \\ &= \prod_{d'=1}^{D} \exp \left( \frac{ 2 y^{(d')} \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) - \left( \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) \right)^2 }{2 \sigma^2} \right) \exp \left( \frac{1}{2 \sigma^2} \frac{\eta_k}{N^{(d)}} \left( 2 \left[ y^{(d)} - \eta \cdot \left( \etd{d} \setminus z_n^{(d)} \right) \right] - \frac{\eta_k}{N^{(d)}} \right) \right) \\ &\propto \exp \left( \frac{1}{2 \sigma^2} \frac{\eta_k}{N^{(d)}} \left( 2 \left[ y^{(d)} - \eta \cdot \left( \etd{d} \setminus z_n^{(d)} \right) \right] - \frac{\eta_k}{N^{(d)}} \right) \right) \end{align}$$

where $\delta_{d, d'}$ is the Kronecker delta.
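To make the update concrete, here is a minimal NumPy sketch of the unnormalized full-conditional weights for a single token, assuming the same count arrays as above (`a` of shape $(D, K)$, `b` of shape $(K, V)$) plus the empirical topic distribution `z_bar_d` of document $d$; the function and argument names are illustrative, not the slda package's API:

def z_token_weights(d, w, k_old, a, b, alpha, beta, eta, y_d, z_bar_d, N_d, sigma2):
    # counts with the token's current assignment k_old removed
    a_d = a[d].astype(float)
    a_d[k_old] -= 1
    b_w = b[:, w].astype(float)
    b_w[k_old] -= 1
    b_tot = b.sum(axis=1).astype(float)
    b_tot[k_old] -= 1
    z_bar_minus = z_bar_d.copy()
    z_bar_minus[k_old] -= 1.0 / N_d
    # LDA factors: topic-term and document-topic terms
    weights = (b_w + beta[w]) / (b_tot + beta.sum()) * (a_d + alpha)
    # sLDA factor: contribution of the response y^{(d)}
    resid = y_d - eta @ z_bar_minus
    weights *= np.exp(eta / N_d * (2.0 * resid - eta / N_d) / (2.0 * sigma2))
    return weights  # normalize these and sample the new topic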

We also need to calculate the full conditional for $\eta$. In order to do this, let $Z = (\etd{1} \cdots \etd{D})$ be the matrix whose columns are the empirical topic distributions $\etd{d}$, let $I$ be the identity matrix and $\one$ be the vector of ones, and note that

$$\prod_{k=1}^{K} \cN{\eta_{k}}{\mu, \nu^2} = \cN{\eta}{\mu \one, \nu^2 I}, \qquad \prod_{d=1}^{D} \cN{y^{(d)}}{\eta \cdot \etd{d}, \sigma^2} = \cN{y}{Z^T \eta, \sigma^2 I}.$$

Therefore

$$\begin{align} \cp{\eta}{z, w, y, \alpha, \beta, \mu, \nu^2, \sigma^2} &\propto \cp{z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} \\ &\propto \cN{\eta}{\mu \one, \nu^2 I} \cN{y}{Z^T \eta, \sigma^2 I} \\ &\propto \cN{\eta}{\Sigma \left( \nu^{-2} \mu \one + \sigma^{-2} Zy \right), \Sigma} \end{align}$$

where $\Sigma^{-1} = \nu^{-2} I + \sigma^{-2} ZZ^T$ (see Section 9.3 of Kevin Murphy's notes for a derivation of Bayes' rule for linear Gaussian systems). It is interesting to consider the mean and covariance of $\eta$ in the two variance regimes $\sigma \gg \nu$ and $\sigma \ll \nu$. If $\sigma \gg \nu$, then

$$\Sigma^{-1} = \nu^{-2} \left( I + \left( \frac{\nu}{\sigma} \right)^2 ZZ^T \right) \approx \nu^{-2} I$$

which implies that the covariance structure of $\eta$ is $\Sigma \approx \nu^2 I$ and the mean of $\eta$ is

$$\Sigma \left( \nu^{-2} \mu \one + \sigma^{-2} Zy \right) \approx \mu \one + \left( \frac{\nu}{\sigma} \right)^2 Zy \approx \mu \one,$$

thus $\eta$ is approximately distributed according to its prior distribution. On the other hand, if $\sigma \ll \nu$, then

$$\Sigma^{-1} = \sigma^{-2} \left( \left( \frac{\sigma}{\nu} \right)^2 I + ZZ^T \right) \approx \sigma^{-2} ZZ^T.$$

Notice that, provided $D \ge K$, $ZZ^T$ is almost surely positive definite, and hence almost surely invertible. Therefore $\Sigma \approx \sigma^2 (ZZ^T)^{-1}$ and

$$\Sigma \left( \nu^{-2} \mu \one + \sigma^{-2} Zy \right) \approx \left( \frac{\sigma}{\nu} \right)^2 \mu (ZZ^T)^{-1} \one + (ZZ^T)^{-1} Zy \approx (ZZ^T)^{-1} Zy,$$

thus $\eta$ is approximately distributed around the least-squares solution of $y = Z^T \eta$.
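A minimal sketch of drawing $\eta$ from the full conditional derived above, assuming `Z` is the $K \times D$ NumPy array whose columns are the empirical topic distributions $\etd{d}$ (again, illustrative names rather than the slda package's API):

def sample_eta(Z, y, mu, nu2, sigma2, rng):
    # precision matrix Sigma^{-1} = nu^{-2} I + sigma^{-2} Z Z^T
    K = Z.shape[0]
    precision = np.eye(K) / nu2 + Z @ Z.T / sigma2
    Sigma = np.linalg.inv(precision)
    mean = Sigma @ (mu * np.ones(K) / nu2 + Z @ y / sigma2)
    return rng.multivariate_normal(mean, Sigma)

# for example: sample_eta(Z, y, mu=0.0, nu2=10.0, sigma2=1.0, rng=np.random.default_rng(42))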

Graphical test


In [2]:
%matplotlib inline

from modules.helpers import plot_images
from functools import partial
from sklearn.metrics import (mean_squared_error)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

imshow = partial(plt.imshow, cmap='gray', interpolation='nearest', aspect='auto')
rmse = lambda y_true, y_pred: np.sqrt(mean_squared_error(y_true, y_pred))
sns.set(style='white')

Generate topics

We assume a vocabulary of 25 terms arranged in a 5 × 5 grid, and create ten "topics": each topic assigns equal probability (0.2) to the five terms in a single row or a single column of the grid.


In [3]:
V = 25
K = 10
N = 100
D = 1000

In [4]:
topics = []
topic_base = np.concatenate((np.ones((1, 5)) * 0.2, np.zeros((4, 5))), axis=0).ravel()
for i in range(5):
    topics.append(np.roll(topic_base, i * 5))
topic_base = np.concatenate((np.ones((5, 1)) * 0.2, np.zeros((5, 4))), axis=1).ravel()
for i in range(5):
    topics.append(np.roll(topic_base, i))
topics = np.array(topics)
plt.figure(figsize=(10, 5))
plot_images(plt, topics, (5, 5), layout=(2, 5), figsize=(10, 5))


<matplotlib.figure.Figure at 0x10f4ac860>

Generate documents from topics

We generate 1,000 documents from these 10 topics: for each document we sample a topic distribution from a Dirichlet distribution with parameter $\alpha = (1, \ldots, 1)$, then draw $N = 100$ tokens according to that distribution and the topic-term distributions above.


In [5]:
alpha = np.ones(K)
np.random.seed(42)
thetas = np.random.dirichlet(alpha, size=D)
topic_assignments = np.array([np.random.choice(range(K), size=100, p=theta)
                              for theta in thetas])
word_assignments = np.array([[np.random.choice(range(V), size=1, p=topics[topic_assignments[d, n]])[0]
                              for n in range(N)] for d in range(D)])
doc_term_matrix = np.array([np.histogram(word_assignments[d], bins=V, range=(0, V - 1))[0] for d in range(D)])
imshow(doc_term_matrix)


Out[5]:
<matplotlib.image.AxesImage at 0x11225ada0>

Generate responses


In [6]:
# choose parameter values
nu2 = 10
sigma2 = 1
np.random.seed(42)
# note: scale is a standard deviation, so nu2 is used here as the scale of the
# draws rather than as the variance nu^2 from the model
eta = np.random.normal(scale=nu2, size=K)
# responses: noiseless means eta . theta^(d) plus Gaussian noise with scale sigma2
y = [np.dot(eta, thetas[i]) for i in range(D)] + np.random.normal(scale=sigma2, size=D)
# plot histogram of responses
print(eta)
_ = plt.hist(y, bins=20)


[  4.96714153  -1.38264301   6.47688538  15.23029856  -2.34153375
  -2.34136957  15.79212816   7.67434729  -4.69474386   5.42560044]

Estimate parameters


In [7]:
from slda.topic_models import SLDA

In [9]:
_K = 10
_alpha = alpha
_beta = np.repeat(0.01, V)
_mu = 0
_nu2 = nu2
_sigma2 = sigma2
n_iter = 500
slda = SLDA(_K, _alpha, _beta, _mu, _nu2, _sigma2, n_iter, seed=42)

In [10]:
%%time
slda.fit(doc_term_matrix, y)


CPU times: user 12.7 s, sys: 23.3 ms, total: 12.8 s
Wall time: 12.8 s

In [11]:
plot_images(plt, slda.phi, (5, 5), (2, 5), figsize=(10, 5))
print(slda.phi)
print(np.sum(slda.phi, axis=0))
print(np.sum(slda.phi, axis=1))


[[  2.09449783e-01   1.00358783e-06   1.00358783e-06   1.00358783e-06
    1.00358783e-06   1.92188072e-01   1.00358783e-06   1.00358783e-06
    1.00358783e-06   1.00358783e-06   1.85965828e-01   1.00358783e-06
    1.10495020e-03   1.00358783e-06   1.00358783e-06   1.88876233e-01
    1.00358783e-06   1.00358783e-06   1.00358783e-06   1.00358783e-06
    2.22396066e-01   1.00358783e-06   1.00358783e-06   1.00358783e-06
    1.00358783e-06]
 [  1.94239884e-01   1.85987570e-01   2.05277355e-01   1.99810197e-01
    2.10331898e-01   1.03153931e-06   1.03153931e-06   1.03153931e-06
    1.03153931e-06   1.03153931e-06   1.03153931e-06   1.03153931e-06
    1.03153931e-06   1.03153931e-06   1.03153931e-06   1.03153931e-06
    1.03153931e-06   1.03153931e-06   4.33349666e-03   1.03153931e-06
    1.03153931e-06   1.03153931e-06   1.03153931e-06   1.03153931e-06
    1.03153931e-06]
 [  9.45067927e-07   9.45067927e-07   9.45067927e-07   9.45067927e-07
    9.45067927e-07   2.01205907e-01   2.02339988e-01   2.00260839e-01
    2.03096043e-01   1.93078323e-01   9.45067927e-07   9.45067927e-07
    9.45067927e-07   9.45067927e-07   9.45067927e-07   9.45067927e-07
    9.45067927e-07   9.45067927e-07   9.45067927e-07   9.45067927e-07
    9.45067927e-07   9.45067927e-07   9.45067927e-07   9.45067927e-07
    9.45067927e-07]
 [  9.74730122e-07   9.74730122e-07   9.84477423e-05   4.58220630e-03
    9.74730122e-07   2.73021907e-03   9.74730122e-07   9.74730122e-07
    9.74730122e-07   9.74730122e-07   9.74730122e-07   9.74730122e-07
    9.74730122e-07   9.74730122e-07   9.74730122e-07   2.05766503e-01
    1.95824256e-01   1.89488510e-01   2.08008383e-01   1.93484904e-01
    9.74730122e-07   9.74730122e-07   9.74730122e-07   9.74730122e-07
    9.74730122e-07]
 [  1.02257331e-06   1.02257331e-06   1.02257331e-06   1.02257331e-06
    1.02257331e-06   1.02257331e-06   1.02257331e-06   1.02257331e-06
    1.02257331e-06   1.02257331e-06   2.13923358e-01   2.03288596e-01
    1.87131938e-01   1.95108009e-01   2.00527648e-01   1.02257331e-06
    1.02257331e-06   1.02257331e-06   1.02257331e-06   1.02257331e-06
    1.02257331e-06   1.02257331e-06   1.02257331e-06   1.02257331e-06
    1.02257331e-06]
 [  9.87532403e-07   9.87532403e-07   1.99482533e-01   9.87532403e-07
    9.87532403e-07   9.87532403e-07   9.87532403e-07   1.85064560e-01
    9.87532403e-07   9.87532403e-07   9.87532403e-07   9.87532403e-07
    2.11925441e-01   9.87532403e-07   9.87532403e-07   9.87532403e-07
    9.87532403e-07   2.01556351e-01   9.87532403e-07   9.87532403e-07
    9.87532403e-07   9.87532403e-07   2.01951364e-01   9.87532403e-07
    9.87532403e-07]
 [  1.02362003e-06   1.02362003e-06   1.02362003e-06   1.93669934e-01
    1.02362003e-06   1.02362003e-06   1.02362003e-06   1.02362003e-06
    1.93772296e-01   1.02362003e-06   1.02362003e-06   1.02362003e-06
    1.02362003e-06   2.05851012e-01   1.78120122e-02   1.03385623e-04
    1.02362003e-06   1.02362003e-06   1.94795916e-01   1.02362003e-06
    1.02362003e-06   1.02362003e-06   1.02362003e-06   1.93977020e-01
    1.02362003e-06]
 [  9.83066676e-07   1.99170292e-01   9.83066676e-07   9.83066676e-07
    9.83066676e-07   9.83066676e-07   1.88749785e-01   9.83066676e-07
    9.83066676e-07   9.83066676e-07   9.83066676e-07   2.01824572e-01
    9.83066676e-07   9.83066676e-07   9.83066676e-07   9.83066676e-07
    2.10377252e-01   9.83066676e-07   9.83066676e-07   9.83066676e-07
    9.83066676e-07   1.99858438e-01   9.83066676e-07   9.83066676e-07
    9.83066676e-07]
 [  1.02825120e-06   1.02825120e-06   1.02825120e-06   1.02825120e-06
    1.86834272e-01   1.02825120e-06   1.02825120e-06   1.02825120e-06
    1.02825120e-06   2.00510013e-01   3.70273258e-03   1.02825120e-06
    1.02825120e-06   1.02825120e-06   1.89610550e-01   1.02825120e-06
    1.02825120e-06   1.03853371e-04   1.02825120e-06   2.12746202e-01
    1.02825120e-06   1.02825120e-06   1.02825120e-06   1.02825120e-06
    2.06473870e-01]
 [  1.00692259e-06   1.00692259e-06   1.00692259e-06   1.00692259e-06
    1.00692259e-06   1.00692259e-06   1.00692259e-06   1.00692259e-06
    1.00692259e-06   1.00692259e-06   1.00692259e-06   1.00692259e-06
    1.00692259e-06   1.00692259e-06   1.00692259e-06   5.63977344e-03
    1.00692259e-06   1.00692259e-06   1.00692259e-06   1.00692259e-06
    1.95444682e-01   1.96552297e-01   1.91215607e-01   2.17093518e-01
    1.94034991e-01]]
[ 0.40369764  0.38516585  0.40486535  0.39806931  0.39717412  0.39613128
  0.39109785  0.38533347  0.39687638  0.39359637  0.40359887  0.40512117
  0.40016932  0.40096698  0.40795714  0.40039189  0.40620956  0.39115573
  0.40714477  0.40623911  0.41784874  0.39641875  0.39317498  0.41107851
  0.40051683]
[ 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]

In [12]:
topic_order = [1, 2, 4, 3, 9, 0, 7, 5, 6, 8]
plot_images(plt, slda.phi[topic_order], (5, 5), (2, 5), figsize=(10, 5))



In [13]:
imshow(slda.theta)


Out[13]:
<matplotlib.image.AxesImage at 0x111890780>

In [14]:
plt.plot(slda.loglikelihoods)


Out[14]:
[<matplotlib.lines.Line2D at 0x111904ba8>]

In [15]:
plt.plot(np.diff(slda.loglikelihoods)[-100:])


Out[15]:
[<matplotlib.lines.Line2D at 0x111663550>]

In [13]:
burn_in = max(n_iter - 100, int(n_iter / 2))
slda.loglikelihoods[burn_in:].mean()


Out[13]:
29182.323844355378

In [14]:
eta_pred = slda.eta[burn_in:].mean(axis=0)
print(eta)
print(eta_pred[topic_order])
np.linalg.norm(eta - eta_pred[topic_order])


[  4.96714153  -1.38264301   6.47688538  15.23029856  -2.34153375
  -2.34136957  15.79212816   7.67434729  -4.69474386   5.42560044]
[  4.44236917  -1.37428922   5.69171791  13.92661777  -1.90987367
  -1.35671636  15.39498326   7.42637763  -2.77220459   4.94725851]
Out[14]:
2.8091957908971406

Predict response of test documents

Create 1,000 test documents using the same generative process as our training documents, and compute their noise-free responses $\eta \cdot \theta^{(d)}$.


In [15]:
np.random.seed(42^2)  # note: ^ is bitwise XOR in Python, so this seed is 40, not 42**2
thetas_test = np.random.dirichlet(np.ones(K), size=D)
topic_assignments_test = np.array([np.random.choice(range(K), size=100, p=theta)
                                   for theta in thetas_test])
word_assignments_test = np.array([[np.random.choice(range(V), size=1, p=topics[topic_assignments_test[d, n]])[0]
                                   for n in range(N)] for d in range(D)])
doc_term_matrix_test = np.array([np.histogram(word_assignments_test[d], bins=V, range=(0, V - 1))[0] for d in range(D)])
y_test = [np.dot(eta, thetas_test[i]) for i in range(D)]
imshow(doc_term_matrix_test)


Out[15]:
<matplotlib.image.AxesImage at 0x10e4d4c50>

Estimate their topic distributions using the trained model, then calculate the predicted responses using the mean of our samples of $\eta$ after burn-in as an estimate for $\eta$.


In [16]:
thetas_test_slda = slda.transform(doc_term_matrix_test)
y_slda = [np.dot(eta_pred, thetas_test_slda[i]) for i in range(D)]

Measure the goodness of our prediction using root mean square error.


In [17]:
rmse(y_test, y_slda)


Out[17]:
0.82113166468454046

Two-step learning: learn topics, then learn regression


In [18]:
from slda.topic_models import LDA

In [19]:
lda = LDA(_K, _alpha, _beta, n_iter, seed=42)

In [20]:
%%time
lda.fit(doc_term_matrix)


2016-06-10 13:13:01.718094 start iterations
2016-06-10 13:13:01.810971 0:00:00.092877 elapsed, iter   10, LL -518559.0163, 8.97% change from last
2016-06-10 13:13:01.886179 0:00:00.168085 elapsed, iter   20, LL -468960.0555, 9.56% change from last
2016-06-10 13:13:01.956784 0:00:00.238690 elapsed, iter   30, LL -445434.0461, 5.02% change from last
2016-06-10 13:13:02.026697 0:00:00.308603 elapsed, iter   40, LL -429314.6963, 3.62% change from last
2016-06-10 13:13:02.099862 0:00:00.381768 elapsed, iter   50, LL -419699.1884, 2.24% change from last
2016-06-10 13:13:02.167735 0:00:00.449641 elapsed, iter   60, LL -415442.5080, 1.01% change from last
2016-06-10 13:13:02.234883 0:00:00.516789 elapsed, iter   70, LL -408896.4141, 1.58% change from last
2016-06-10 13:13:02.301891 0:00:00.583797 elapsed, iter   80, LL -402998.4074, 1.44% change from last
2016-06-10 13:13:02.368526 0:00:00.650432 elapsed, iter   90, LL -398554.0889, 1.10% change from last
2016-06-10 13:13:02.437271 0:00:00.719177 elapsed, iter  100, LL -396059.2952, 0.63% change from last
2016-06-10 13:13:02.507638 0:00:00.789544 elapsed, iter  110, LL -393946.6660, 0.53% change from last
2016-06-10 13:13:02.573177 0:00:00.855083 elapsed, iter  120, LL -393350.3096, 0.15% change from last
2016-06-10 13:13:02.638212 0:00:00.920118 elapsed, iter  130, LL -391582.5349, 0.45% change from last
2016-06-10 13:13:02.703620 0:00:00.985526 elapsed, iter  140, LL -389979.6403, 0.41% change from last
2016-06-10 13:13:02.769137 0:00:01.051043 elapsed, iter  150, LL -390480.8703, -0.13% change from last
2016-06-10 13:13:02.839140 0:00:01.121046 elapsed, iter  160, LL -389499.9872, 0.25% change from last
2016-06-10 13:13:02.904864 0:00:01.186770 elapsed, iter  170, LL -389235.6565, 0.07% change from last
2016-06-10 13:13:02.969981 0:00:01.251887 elapsed, iter  180, LL -388759.9274, 0.12% change from last
2016-06-10 13:13:03.034976 0:00:01.316882 elapsed, iter  190, LL -388820.9941, -0.02% change from last
2016-06-10 13:13:03.100102 0:00:01.382008 elapsed, iter  200, LL -388260.5027, 0.14% change from last
2016-06-10 13:13:03.170388 0:00:01.452294 elapsed, iter  210, LL -388249.9451, 0.00% change from last
2016-06-10 13:13:03.236702 0:00:01.518608 elapsed, iter  220, LL -388140.9833, 0.03% change from last
2016-06-10 13:13:03.301888 0:00:01.583794 elapsed, iter  230, LL -388708.6353, -0.15% change from last
2016-06-10 13:13:03.368221 0:00:01.650127 elapsed, iter  240, LL -388470.8166, 0.06% change from last
2016-06-10 13:13:03.433243 0:00:01.715149 elapsed, iter  250, LL -388586.6879, -0.03% change from last
2016-06-10 13:13:03.503369 0:00:01.785275 elapsed, iter  260, LL -388759.3606, -0.04% change from last
2016-06-10 13:13:03.569231 0:00:01.851137 elapsed, iter  270, LL -388112.9995, 0.17% change from last
2016-06-10 13:13:03.634155 0:00:01.916061 elapsed, iter  280, LL -388325.5728, -0.05% change from last
2016-06-10 13:13:03.699221 0:00:01.981127 elapsed, iter  290, LL -388194.6193, 0.03% change from last
2016-06-10 13:13:03.764164 0:00:02.046070 elapsed, iter  300, LL -388022.2504, 0.04% change from last
2016-06-10 13:13:03.829746 0:00:02.111652 elapsed, iter  310, LL -388035.5122, -0.00% change from last
2016-06-10 13:13:03.901004 0:00:02.182910 elapsed, iter  320, LL -387834.7560, 0.05% change from last
2016-06-10 13:13:03.973892 0:00:02.255798 elapsed, iter  330, LL -387896.1408, -0.02% change from last
2016-06-10 13:13:04.045441 0:00:02.327347 elapsed, iter  340, LL -387689.7851, 0.05% change from last
2016-06-10 13:13:04.115698 0:00:02.397604 elapsed, iter  350, LL -387410.1707, 0.07% change from last
2016-06-10 13:13:04.186569 0:00:02.468475 elapsed, iter  360, LL -387194.6374, 0.06% change from last
2016-06-10 13:13:04.257784 0:00:02.539690 elapsed, iter  370, LL -387956.5177, -0.20% change from last
2016-06-10 13:13:04.328896 0:00:02.610802 elapsed, iter  380, LL -387593.8931, 0.09% change from last
2016-06-10 13:13:04.399921 0:00:02.681827 elapsed, iter  390, LL -387841.1371, -0.06% change from last
2016-06-10 13:13:04.470958 0:00:02.752864 elapsed, iter  400, LL -388194.7104, -0.09% change from last
2016-06-10 13:13:04.541473 0:00:02.823379 elapsed, iter  410, LL -387557.6239, 0.16% change from last
2016-06-10 13:13:04.612267 0:00:02.894173 elapsed, iter  420, LL -387654.8058, -0.03% change from last
2016-06-10 13:13:04.682486 0:00:02.964392 elapsed, iter  430, LL -388271.5004, -0.16% change from last
2016-06-10 13:13:04.753022 0:00:03.034928 elapsed, iter  440, LL -387805.9252, 0.12% change from last
2016-06-10 13:13:04.823710 0:00:03.105616 elapsed, iter  450, LL -388152.3984, -0.09% change from last
2016-06-10 13:13:04.894584 0:00:03.176490 elapsed, iter  460, LL -388539.1905, -0.10% change from last
2016-06-10 13:13:04.965270 0:00:03.247176 elapsed, iter  470, LL -387818.0986, 0.19% change from last
2016-06-10 13:13:05.035921 0:00:03.317827 elapsed, iter  480, LL -387694.9035, 0.03% change from last
2016-06-10 13:13:05.106130 0:00:03.388036 elapsed, iter  490, LL -387909.3207, -0.06% change from last
CPU times: user 3.45 s, sys: 15.4 ms, total: 3.46 s
Wall time: 3.46 s

In [21]:
plot_images(plt, lda.phi, (5, 5), (2, 5), figsize=(10, 5))



In [22]:
topic_order_lda = [1, 2, 0, 3, 9, 4, 7, 5, 6, 8]
plot_images(plt, lda.phi[topic_order_lda], (5, 5), (2, 5), figsize=(10, 5))



In [23]:
imshow(lda.theta)


Out[23]:
<matplotlib.image.AxesImage at 0x1135bf208>

In [24]:
plt.plot(lda.loglikelihoods)


Out[24]:
[<matplotlib.lines.Line2D at 0x1135b2630>]

In [25]:
thetas_test_lda = lda.transform(doc_term_matrix_test)

Unregularized linear regression

  • train linear regression on training data
  • calculate response on test data
  • measure the goodness of prediction using root mean square error.

In [26]:
from sklearn.linear_model import LinearRegression

In [27]:
lr = LinearRegression(fit_intercept=False)
lr.fit(lda.theta, y)
y_lr = lr.predict(thetas_test_lda)
rmse(y_test, y_lr)


Out[27]:
0.82647226636711646

In [28]:
print(eta)
print(lr.coef_[topic_order_lda])
np.linalg.norm(eta - lr.coef_[topic_order_lda])


[  4.96714153  -1.38264301   6.47688538  15.23029856  -2.34153375
  -2.34136957  15.79212816   7.67434729  -4.69474386   5.42560044]
[  4.29349447  -1.35662767   5.86690601  13.56743486  -2.225693    -1.3096669
  15.68197752   7.74802135  -2.83164145   4.58107028]
Out[28]:
2.978496259038288

L2-regularized linear regression

  • train ridge regression on training data
  • calculate response on test data
  • measure the goodness of prediction using root mean square error.

In [29]:
from sklearn.linear_model import Ridge

In [30]:
lrl2 = Ridge(alpha=1., fit_intercept=False)
lrl2.fit(lda.theta, y)
y_lrl2 = lrl2.predict(thetas_test_lda)
rmse(y_test, y_lrl2)


Out[30]:
0.86174658080698563

In [31]:
print(eta)
print(lrl2.coef_[topic_order_lda])
np.linalg.norm(eta - lrl2.coef_[topic_order_lda])


[  4.96714153  -1.38264301   6.47688538  15.23029856  -2.34153375
  -2.34136957  15.79212816   7.67434729  -4.69474386   5.42560044]
[  4.28631637  -0.95278145   5.67416731  12.66995942  -1.48722133
  -0.7396339   14.37412703   7.40026967  -2.0793953    4.49485813]
Out[31]:
4.5755413457461787

Gradient boosted regression trees

  • train gradient boosted regressor on training data
  • calculate response on test data
  • measure the goodness of prediction using root mean square error.

In [32]:
from sklearn.ensemble import GradientBoostingRegressor

In [33]:
gbr = GradientBoostingRegressor()
gbr.fit(lda.theta, y)
y_gbr = gbr.predict(thetas_test_lda)
rmse(y_test, y_gbr)


Out[33]:
0.98198962236756349

Conclusion

SLDA is slightly better than unregularized linear regression, and better than ridge regression and gradient boosted regression trees. The near-identical performance to unregularized linear regression is likely because this test was set up as an exact problem: every parameter used in training, except $\beta$ (we hand-picked the topics) and $\eta$ (the parameter we wanted to learn), was the one used to generate the documents. The main remaining difference is that in SLDA the likelihood of $y$ introduces a regularizing term on $\eta$ into the full conditional for $z$, whereas unregularized linear regression enforces no such penalty.

Test with fewer topics

We now redo the previous test, but this time with fewer topics, in order to determine whether

  1. the supervised portion of SLDA will produce topics different from those produced by LDA, and
  2. prediction with SLDA is improved over LDA-and-a-regression.

Because we will no longer be solving an exact problem (the number of topics, and hence $\alpha$, will both differ from those used to generate the documents), we expect SLDA to do better than any LDA-and-a-regression method, including unregularized linear regression.


In [34]:
_K = 5
_alpha = np.repeat(1. / _K, _K)
_beta = np.repeat(0.01, V)
_mu = 0
_nu2 = nu2
_sigma2 = sigma2
n_iter = 500
slda1 = SLDA(_K, _alpha, _beta, _mu, _nu2, _sigma2, n_iter, seed=42)

In [35]:
%%time
slda1.fit(doc_term_matrix, y)


2016-06-10 13:13:08.127915 start iterations
2016-06-10 13:13:08.283610 0:00:00.155695 elapsed, iter   10, LL -48424.3678, 57.47% change from last
2016-06-10 13:13:08.426567 0:00:00.298652 elapsed, iter   20, LL -26669.8841, 44.92% change from last
2016-06-10 13:13:08.565257 0:00:00.437342 elapsed, iter   30, LL -23012.4319, 13.71% change from last
2016-06-10 13:13:08.697874 0:00:00.569959 elapsed, iter   40, LL -22168.0622, 3.67% change from last
2016-06-10 13:13:08.833184 0:00:00.705269 elapsed, iter   50, LL -20919.7416, 5.63% change from last
2016-06-10 13:13:08.971179 0:00:00.843264 elapsed, iter   60, LL -20442.9833, 2.28% change from last
2016-06-10 13:13:09.113854 0:00:00.985939 elapsed, iter   70, LL -21271.4782, -4.05% change from last
2016-06-10 13:13:09.262225 0:00:01.134310 elapsed, iter   80, LL -21511.5328, -1.13% change from last
2016-06-10 13:13:09.404185 0:00:01.276270 elapsed, iter   90, LL -21498.6877, 0.06% change from last
2016-06-10 13:13:09.546102 0:00:01.418187 elapsed, iter  100, LL -21450.9848, 0.22% change from last
2016-06-10 13:13:09.689745 0:00:01.561830 elapsed, iter  110, LL -21264.0617, 0.87% change from last
2016-06-10 13:13:09.832946 0:00:01.705031 elapsed, iter  120, LL -21124.3302, 0.66% change from last
2016-06-10 13:13:09.972801 0:00:01.844886 elapsed, iter  130, LL -21409.3146, -1.35% change from last
2016-06-10 13:13:10.114312 0:00:01.986397 elapsed, iter  140, LL -20748.2025, 3.09% change from last
2016-06-10 13:13:10.254950 0:00:02.127035 elapsed, iter  150, LL -21784.6926, -5.00% change from last
2016-06-10 13:13:10.399699 0:00:02.271784 elapsed, iter  160, LL -21593.1937, 0.88% change from last
2016-06-10 13:13:10.543634 0:00:02.415719 elapsed, iter  170, LL -21755.2391, -0.75% change from last
2016-06-10 13:13:10.679132 0:00:02.551217 elapsed, iter  180, LL -21838.7155, -0.38% change from last
2016-06-10 13:13:10.812856 0:00:02.684941 elapsed, iter  190, LL -22390.0410, -2.52% change from last
2016-06-10 13:13:10.951874 0:00:02.823959 elapsed, iter  200, LL -21844.0314, 2.44% change from last
2016-06-10 13:13:11.088078 0:00:02.960163 elapsed, iter  210, LL -22356.2309, -2.34% change from last
2016-06-10 13:13:11.221326 0:00:03.093411 elapsed, iter  220, LL -22476.3578, -0.54% change from last
2016-06-10 13:13:11.357571 0:00:03.229656 elapsed, iter  230, LL -21854.6413, 2.77% change from last
2016-06-10 13:13:11.494444 0:00:03.366529 elapsed, iter  240, LL -22209.4309, -1.62% change from last
2016-06-10 13:13:11.635861 0:00:03.507946 elapsed, iter  250, LL -22552.7551, -1.55% change from last
2016-06-10 13:13:11.771394 0:00:03.643479 elapsed, iter  260, LL -22492.8254, 0.27% change from last
2016-06-10 13:13:11.904591 0:00:03.776676 elapsed, iter  270, LL -21363.6117, 5.02% change from last
2016-06-10 13:13:12.039422 0:00:03.911507 elapsed, iter  280, LL -21755.4643, -1.83% change from last
2016-06-10 13:13:12.174180 0:00:04.046265 elapsed, iter  290, LL -21026.2013, 3.35% change from last
2016-06-10 13:13:12.310638 0:00:04.182723 elapsed, iter  300, LL -21019.0059, 0.03% change from last
2016-06-10 13:13:12.445832 0:00:04.317917 elapsed, iter  310, LL -21956.2779, -4.46% change from last
2016-06-10 13:13:12.581865 0:00:04.453950 elapsed, iter  320, LL -21733.3422, 1.02% change from last
2016-06-10 13:13:12.718181 0:00:04.590266 elapsed, iter  330, LL -21173.1530, 2.58% change from last
2016-06-10 13:13:12.858645 0:00:04.730730 elapsed, iter  340, LL -21356.0593, -0.86% change from last
2016-06-10 13:13:12.994323 0:00:04.866408 elapsed, iter  350, LL -20630.9341, 3.40% change from last
2016-06-10 13:13:13.128332 0:00:05.000417 elapsed, iter  360, LL -20467.3912, 0.79% change from last
2016-06-10 13:13:13.263090 0:00:05.135175 elapsed, iter  370, LL -21047.2518, -2.83% change from last
2016-06-10 13:13:13.402083 0:00:05.274168 elapsed, iter  380, LL -21253.7650, -0.98% change from last
2016-06-10 13:13:13.539833 0:00:05.411918 elapsed, iter  390, LL -21504.8653, -1.18% change from last
2016-06-10 13:13:13.677880 0:00:05.549965 elapsed, iter  400, LL -21712.5750, -0.97% change from last
2016-06-10 13:13:13.811928 0:00:05.684013 elapsed, iter  410, LL -22065.4088, -1.63% change from last
2016-06-10 13:13:13.948051 0:00:05.820136 elapsed, iter  420, LL -21642.4839, 1.92% change from last
2016-06-10 13:13:14.082922 0:00:05.955007 elapsed, iter  430, LL -21622.6513, 0.09% change from last
2016-06-10 13:13:14.224415 0:00:06.096500 elapsed, iter  440, LL -22010.7755, -1.79% change from last
2016-06-10 13:13:14.359776 0:00:06.231861 elapsed, iter  450, LL -22175.4719, -0.75% change from last
2016-06-10 13:13:14.495657 0:00:06.367742 elapsed, iter  460, LL -22189.0535, -0.06% change from last
2016-06-10 13:13:14.630771 0:00:06.502856 elapsed, iter  470, LL -21675.9985, 2.31% change from last
2016-06-10 13:13:14.766175 0:00:06.638260 elapsed, iter  480, LL -21900.0566, -1.03% change from last
2016-06-10 13:13:14.900684 0:00:06.772769 elapsed, iter  490, LL -21701.0455, 0.91% change from last
CPU times: user 6.88 s, sys: 33.5 ms, total: 6.91 s
Wall time: 6.91 s

In [36]:
plot_images(plt, slda1.phi, (5, 5), (1, 5), figsize=(10, 5))



In [37]:
imshow(slda1.theta)


Out[37]:
<matplotlib.image.AxesImage at 0x11404b588>

In [38]:
plt.plot(slda1.loglikelihoods)


Out[38]:
[<matplotlib.lines.Line2D at 0x11408fdd8>]

In [39]:
burn_in1 = max(n_iter - 100, int(n_iter / 2))
slda1.loglikelihoods[burn_in1:].mean()


Out[39]:
-22055.005897817184

In [40]:
eta_pred1 = slda1.eta[burn_in1:].mean(axis=0)
eta_pred1


Out[40]:
array([  2.22220264,  -0.15788853,   9.55428965,  10.2525672 ,   3.22556876])

In [41]:
thetas_test_slda1 = slda1.transform(doc_term_matrix_test)
y_slda1 = [np.dot(eta_pred1, thetas_test_slda1[i]) for i in range(D)]

In [42]:
rmse(y_test, y_slda1)


Out[42]:
1.2455951548532411

In [43]:
lda1 = LDA(_K, _alpha, _beta, n_iter, seed=42)

In [44]:
%%time
lda1.fit(doc_term_matrix)


2016-06-10 13:13:16.087473 start iterations
2016-06-10 13:13:16.143708 0:00:00.056235 elapsed, iter   10, LL -404994.8904, 13.25% change from last
2016-06-10 13:13:16.188322 0:00:00.100849 elapsed, iter   20, LL -380416.0984, 6.07% change from last
2016-06-10 13:13:16.232188 0:00:00.144715 elapsed, iter   30, LL -374718.7031, 1.50% change from last
2016-06-10 13:13:16.275514 0:00:00.188041 elapsed, iter   40, LL -372662.5340, 0.55% change from last
2016-06-10 13:13:16.318583 0:00:00.231110 elapsed, iter   50, LL -372592.3679, 0.02% change from last
2016-06-10 13:13:16.361522 0:00:00.274049 elapsed, iter   60, LL -371993.9718, 0.16% change from last
2016-06-10 13:13:16.404200 0:00:00.316727 elapsed, iter   70, LL -370991.9632, 0.27% change from last
2016-06-10 13:13:16.446894 0:00:00.359421 elapsed, iter   80, LL -372006.8433, -0.27% change from last
2016-06-10 13:13:16.490452 0:00:00.402979 elapsed, iter   90, LL -372236.5623, -0.06% change from last
2016-06-10 13:13:16.532751 0:00:00.445278 elapsed, iter  100, LL -373132.6808, -0.24% change from last
2016-06-10 13:13:16.575374 0:00:00.487901 elapsed, iter  110, LL -372307.9178, 0.22% change from last
2016-06-10 13:13:16.618148 0:00:00.530675 elapsed, iter  120, LL -373064.4300, -0.20% change from last
2016-06-10 13:13:16.661584 0:00:00.574111 elapsed, iter  130, LL -372896.9690, 0.04% change from last
2016-06-10 13:13:16.704663 0:00:00.617190 elapsed, iter  140, LL -373666.8725, -0.21% change from last
2016-06-10 13:13:16.747776 0:00:00.660303 elapsed, iter  150, LL -375130.4420, -0.39% change from last
2016-06-10 13:13:16.791435 0:00:00.703962 elapsed, iter  160, LL -375266.7223, -0.04% change from last
2016-06-10 13:13:16.834670 0:00:00.747197 elapsed, iter  170, LL -375293.3698, -0.01% change from last
2016-06-10 13:13:16.878672 0:00:00.791199 elapsed, iter  180, LL -376238.5505, -0.25% change from last
2016-06-10 13:13:16.922925 0:00:00.835452 elapsed, iter  190, LL -375997.5441, 0.06% change from last
2016-06-10 13:13:16.966738 0:00:00.879265 elapsed, iter  200, LL -375560.6490, 0.12% change from last
2016-06-10 13:13:17.011130 0:00:00.923657 elapsed, iter  210, LL -375438.6514, 0.03% change from last
2016-06-10 13:13:17.053203 0:00:00.965730 elapsed, iter  220, LL -375317.4478, 0.03% change from last
2016-06-10 13:13:17.100620 0:00:01.013147 elapsed, iter  230, LL -375611.4260, -0.08% change from last
2016-06-10 13:13:17.142928 0:00:01.055455 elapsed, iter  240, LL -375084.7172, 0.14% change from last
2016-06-10 13:13:17.186384 0:00:01.098911 elapsed, iter  250, LL -374881.3455, 0.05% change from last
2016-06-10 13:13:17.230083 0:00:01.142610 elapsed, iter  260, LL -375142.8976, -0.07% change from last
2016-06-10 13:13:17.275077 0:00:01.187604 elapsed, iter  270, LL -374933.0420, 0.06% change from last
2016-06-10 13:13:17.317606 0:00:01.230133 elapsed, iter  280, LL -374389.2267, 0.15% change from last
2016-06-10 13:13:17.359826 0:00:01.272353 elapsed, iter  290, LL -373889.6761, 0.13% change from last
2016-06-10 13:13:17.406300 0:00:01.318827 elapsed, iter  300, LL -373999.4793, -0.03% change from last
2016-06-10 13:13:17.454818 0:00:01.367345 elapsed, iter  310, LL -374055.1543, -0.01% change from last
2016-06-10 13:13:17.498775 0:00:01.411302 elapsed, iter  320, LL -374575.8085, -0.14% change from last
2016-06-10 13:13:17.543644 0:00:01.456171 elapsed, iter  330, LL -375653.2470, -0.29% change from last
2016-06-10 13:13:17.587968 0:00:01.500495 elapsed, iter  340, LL -376213.0731, -0.15% change from last
2016-06-10 13:13:17.630953 0:00:01.543480 elapsed, iter  350, LL -375100.3852, 0.30% change from last
2016-06-10 13:13:17.675820 0:00:01.588347 elapsed, iter  360, LL -374749.8943, 0.09% change from last
2016-06-10 13:13:17.719917 0:00:01.632444 elapsed, iter  370, LL -374970.1822, -0.06% change from last
2016-06-10 13:13:17.765062 0:00:01.677589 elapsed, iter  380, LL -375792.9562, -0.22% change from last
2016-06-10 13:13:17.809414 0:00:01.721941 elapsed, iter  390, LL -375580.5170, 0.06% change from last
2016-06-10 13:13:17.853399 0:00:01.765926 elapsed, iter  400, LL -376051.9297, -0.13% change from last
2016-06-10 13:13:17.896977 0:00:01.809504 elapsed, iter  410, LL -375554.2968, 0.13% change from last
2016-06-10 13:13:17.940666 0:00:01.853193 elapsed, iter  420, LL -375327.1425, 0.06% change from last
2016-06-10 13:13:17.984309 0:00:01.896836 elapsed, iter  430, LL -374650.8399, 0.18% change from last
2016-06-10 13:13:18.028397 0:00:01.940924 elapsed, iter  440, LL -375878.8372, -0.33% change from last
2016-06-10 13:13:18.072662 0:00:01.985189 elapsed, iter  450, LL -376199.3532, -0.09% change from last
2016-06-10 13:13:18.116349 0:00:02.028876 elapsed, iter  460, LL -375946.0416, 0.07% change from last
2016-06-10 13:13:18.160254 0:00:02.072781 elapsed, iter  470, LL -376240.8013, -0.08% change from last
2016-06-10 13:13:18.203976 0:00:02.116503 elapsed, iter  480, LL -376057.7271, 0.05% change from last
2016-06-10 13:13:18.247017 0:00:02.159544 elapsed, iter  490, LL -376733.2730, -0.18% change from last
CPU times: user 2.2 s, sys: 13.7 ms, total: 2.21 s
Wall time: 2.22 s

In [45]:
plot_images(plt, lda1.phi, (5, 5), (1, 5), figsize=(10, 5))


We plot the SLDA topics again and note that they are indeed different!


In [46]:
plot_images(plt, slda1.phi, (5, 5), (1, 5), figsize=(10, 5))



In [47]:
imshow(lda1.theta)


Out[47]:
<matplotlib.image.AxesImage at 0x10e340748>

In [48]:
plt.plot(lda1.loglikelihoods)


Out[48]:
[<matplotlib.lines.Line2D at 0x10a47a8d0>]

In [49]:
thetas_test_lda1 = lda1.transform(doc_term_matrix_test)

Unregularized linear regression


In [50]:
lr1 = LinearRegression(fit_intercept=False)
lr1.fit(lda1.theta, y)
y_lr1 = lr1.predict(thetas_test_lda1)
rmse(y_test, y_lr1)


Out[50]:
1.7255065570452643

L2-regularized linear regression


In [51]:
lrl21 = Ridge(alpha=1., fit_intercept=False)
lrl21.fit(lda1.theta, y)
y_lrl21 = lrl21.predict(thetas_test_lda1)
rmse(y_test, y_lrl21)


Out[51]:
1.7266314501798159

Gradient boosted regression trees


In [52]:
gbr1 = GradientBoostingRegressor()
gbr1.fit(lda1.theta, y)
y_gbr1 = gbr1.predict(thetas_test_lda1)
rmse(y_test, y_gbr1)


Out[52]:
1.790809644071969

Unregularized linear regression with SLDA topics


In [53]:
lr1_0 = LinearRegression(fit_intercept=False)
lr1_0.fit(slda1.theta, y)
y_lr1_0 = lr1_0.predict(thetas_test_slda1)
rmse(y_test, y_lr1_0)


Out[53]:
1.2424204767358347

L2-regularized linear regression with SLDA topics


In [54]:
lrl21_0 = Ridge(alpha=1., fit_intercept=False)
lrl21_0.fit(slda1.theta, y)
y_lrl21_0 = lrl21_0.predict(thetas_test_slda1)
rmse(y_test, y_lrl21_0)


Out[54]:
1.2497459065224035

Gradient boosted regression trees with SLDA topics


In [55]:
gbr1_0 = GradientBoostingRegressor()
gbr1_0.fit(slda1.theta, y)
y_gbr1_0 = gbr1_0.predict(thetas_test_slda1)
rmse(y_test, y_gbr1_0)


Out[55]:
1.3036177236719277

Conclusion for test with fewer topics

SLDA achieves a root mean square error at least 27.5% lower than any of the LDA-and-a-regression methods, and performs comparably to the regressions trained on the SLDA topics.