Here is the collapsed Gibbs sampler for Blei and McAuliffe's supervised topic models. I am building on the collapsed Gibbs sampler I wrote for latent Dirichlet allocation.
The generative model for supervised LDA is as follows:
$$\begin{align} \theta^{(d)} &\sim \dir(\alpha) &\text{(topic distribution for document $d \in \{1, \ldots, D\}$)} \\ \phi^{(k)} &\sim \dir(\beta) &\text{(term distribution for topic $k \in \{1, \ldots, K\}$)} \\ z_n^{(d)} \mid \theta^{(d)} &\sim \dis \left( \theta^{(d)} \right) &\text{(topic of $n$th token of document $d$, $n \in \{1, \ldots, N^{(d)}\}$)} \\ w_n^{(d)} \mid \phi^{(z_n^{(d)})} &\sim \dis \left( \phi^{(z_n^{(d)})} \right) &\text{(term of $n$th token of document $d$, $n \in \{1, \ldots, N^{(d)}\}$)} \\ \eta_k &\sim \normal \left( \mu, \nu^2 \right) &\text{(regression coefficient for topic $k \in \{1, \ldots, K\}$)} \\ y^{(d)} \mid \eta, \etd{d} &\sim \normal \left( \eta \cdot \etd{d}, \sigma^2 \right) &\text{(response value of document $d \in \{1, \ldots, D\}$)} \end{align}$$where each token can be any one of $V$ terms in our vocabulary, and $\etd{d}$ is the empirical topic distribution of document $d$.
Plate notation for supervised latent Dirichlet allocation.
In this diagram, $\beta_k$ should be replaced by $\phi^{(k)}$, and each $\phi^{(k)}$ should depend on a single hyperparameter $\beta$.
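To make the notation concrete, here is a minimal sketch of sampling a single document and its response from this generative process. The function and its arguments are mine, purely for illustration; the simulation cells later in this notebook do the same thing for the whole corpus.
import numpy as np

def generate_document(phi, alpha, eta, N, sigma2):
    """Sample one document and its response; phi is a (K, V) array of term distributions."""
    K, V = phi.shape
    theta = np.random.dirichlet(alpha)                           # theta^(d) ~ Dir(alpha)
    z = np.random.choice(K, size=N, p=theta)                     # z_n^(d) ~ Dis(theta^(d))
    w = np.array([np.random.choice(V, p=phi[k]) for k in z])     # w_n^(d) ~ Dis(phi^(z_n^(d)))
    z_bar = np.bincount(z, minlength=K) / N                      # empirical topic distribution
    y = np.random.normal(np.dot(eta, z_bar), np.sqrt(sigma2))    # y^(d) ~ N(eta . z_bar, sigma^2)
    return w, z, y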
The joint probability distribution can be factored as follows:
$$\begin{align} \cp{\theta, \phi, z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} &= \prod_{k=1}^{K} \cp{\phi^{(k)}}{\beta} \prod_{d=1}^{D} \cp{\theta^{(d)}}{\alpha} \prod_{n=1}^{N^{(d)}} \cp{z_n^{(d)}}{\theta^{(d)}} \cp{w_n^{(d)}}{\phi^{(z_n^{(d)})}} \\ & \quad \times \prod_{k'=1}^{K} \cp{\eta_{k'}}{\mu, \nu^2} \prod_{d'=1}^D \cp{y^{(d')}}{\eta, \etd{d'}, \sigma^2} \\ &= \prod_{k=1}^{K} \frac{\Betaf(b^{(k)} + \beta)}{\Betaf(\beta)} \cp{\phi^{(k)}}{b^{(k)} + \beta} \prod_{d=1}^{D} \frac{\Betaf(a^{(d)} + \alpha)}{\Betaf(\alpha)} \cp{\theta^{(d)}}{a^{(d)} + \alpha} \\ &\quad \times \prod_{k'=1}^{K} \cN{\eta_{k'}}{\mu, \nu^2} \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} \end{align}$$where $a_k^{(d)}$ is the number of tokens in document $d$ assigned to topic $k$, $b_v^{(k)}$ is the number of tokens equal to term $v$ and assigned to topic $k$, and $\Betaf$ is the multivariate Beta function. Marginalizing out $\theta$ and $\phi$ by integrating with respect to each $\theta^{(d)}$ and $\phi^{(k)}$ over their respective sample spaces yields
$$\begin{align} \cp{z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} &= \prod_{k=1}^{K} \frac{\Betaf(b^{(k)} + \beta)}{\Betaf(\beta)} \prod_{d=1}^{D} \frac{\Betaf(a^{(d)} + \alpha)}{\Betaf(\alpha)} \prod_{k'=1}^{K} \cN{\eta_{k'}}{\mu, \nu^2} \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} \\ &= \cp{w}{z, \beta} \cp{z}{\alpha} \cp{\eta}{\mu, \nu^2} \cp{y}{\eta, z, \sigma^2}. \end{align}$$See my LDA notebook for step-by-step details of the previous two calculations.
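For reference, the key identity behind both marginalizations is the Dirichlet integral; sketching the $\theta^{(d)}$ case (the $\phi^{(k)}$ case is identical, with $\beta$ and the counts $b^{(k)}$):
$$\int \cp{\theta^{(d)}}{\alpha} \prod_{n=1}^{N^{(d)}} \cp{z_n^{(d)}}{\theta^{(d)}} \, d\theta^{(d)} = \frac{1}{\Betaf(\alpha)} \int \prod_{k=1}^{K} \left( \theta_k^{(d)} \right)^{a_k^{(d)} + \alpha_k - 1} d\theta^{(d)} = \frac{\Betaf(a^{(d)} + \alpha)}{\Betaf(\alpha)}.$$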
Our goal is to calculate the posterior distribution
$$\cp{z, \eta}{w, y, \alpha, \beta, \mu, \nu^2, \sigma^2} = \frac{\cp{z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2}} {\sum_{z'} \int \cp{z', w, \eta', y}{\alpha, \beta, \mu, \nu^2, \sigma^2} d\eta'}$$in order to infer the topic assignments $z$ and regression coefficients $\eta$ from the given term assignments $w$ and response data $y$. Since calculating this directly is infeasible, we resort to collapsed Gibbs sampling. The sampler is "collapsed" because we marginalized out $\theta$ and $\phi$, and will estimate them from the topic assignments $z$:
$$\hat\theta_k^{(d)} = \frac{a_k^{(d)} + \alpha_k}{\sum_{k'=1}^K \left(a_{k'}^{(d)} + \alpha_{k'} \right)},\quad \hat\phi_v^{(k)} = \frac{b_v^{(k)} + \beta_v}{\sum_{v'=1}^V \left(b_{v'}^{(k)} + \beta_{v'} \right)}.$$Gibbs sampling requires us to compute the full conditionals for each $z_n^{(d)}$ and $\eta_k$, i.e. we need to calculate, for all $n$, $d$ and $k$,
$$\begin{align} \cp{z_n^{(d)} = k}{z \setminus z_n^{(d)}, w, \eta, y, \alpha, \beta, \mu, \nu^2, \sigma^2} &\propto \cp{z_n^{(d)} = k, z \setminus z_n^{(d)}, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} \\ &\propto \frac{b_{w_n^{(d)}}^{(k)} \setminus z_n^{(d)} + \beta_{w_n^{(d)}}}{ \sum_{v=1}^V \left( b_v^{(k)} \setminus z_n^{(d)} + \beta_v\right)} \left( a_k^{(d)} \setminus z_n^{(d)} + \alpha_k \right) \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} \\ &\propto \frac{b_{w_n^{(d)}}^{(k)} \setminus z_n^{(d)} + \beta_{w_n^{(d)}}}{ \sum_{v=1}^V \left( b_v^{(k)} \setminus z_n^{(d)} + \beta_v\right)} \left( a_k^{(d)} \setminus z_n^{(d)} + \alpha_k \right) \exp \left( \frac{1}{2 \sigma^2} \frac{\eta_k}{N^{(d)}} \left( 2 \left[ y^{(d)} - \eta \cdot \left( \etd{d} \setminus z_n^{(d)} \right) \right] - \frac{\eta_k}{N^{(d)}} \right) \right) \end{align}$$where the "set-minus" notation $\cdot \setminus z_n^{(d)}$ denotes the variable the notation is applied to with the entry $z_n^{(d)}$ removed (again, see my LDA notebook for details). This final proportionality is true since
$$\begin{align} \prod_{d'=1}^{D} \cN{y^{(d')}}{\eta \cdot \etd{d'}, \sigma^2} &\propto \prod_{d'=1}^{D} \exp \left( -\frac{ \left( y^{(d')} - \eta \cdot \etd{d'} \right)^2 }{2 \sigma^2} \right) \\ &\propto \prod_{d'=1}^{D} \exp \left( \frac{ 2 y^{(d')} \eta \cdot \etd{d'} - \left( \eta \cdot \etd{d'} \right)^2 }{2 \sigma^2} \right) \\ &= \prod_{d'=1}^{D} \exp \left( \frac{ 2 y^{(d')} \left( \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) + \delta_{d, d'} \frac{\eta_k}{N^{(d)}} \right) - \left( \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) + \delta_{d, d'} \frac{\eta_k}{N^{(d)}} \right)^2 }{2 \sigma^2} \right) \\ &= \prod_{d'=1}^{D} \exp \left( \frac{ 2 y^{(d')} \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) - \left( \eta \cdot \left( \etd{d'} \setminus z_n^{(d)} \right) \right)^2 }{2 \sigma^2} \right) \exp \left( \frac{1}{2 \sigma^2} \frac{\eta_k}{N^{(d)}} \left( 2 \left[ y^{(d)} - \eta \cdot \left( \etd{d} \setminus z_n^{(d)} \right) \right] - \frac{\eta_k}{N^{(d)}} \right) \right) \\ &\propto \exp \left( \frac{1}{2 \sigma^2} \frac{\eta_k}{N^{(d)}} \left( 2 \left[ y^{(d)} - \eta \cdot \left( \etd{d} \setminus z_n^{(d)} \right) \right] - \frac{\eta_k}{N^{(d)}} \right) \right) \end{align}$$where $\delta_{d, d'}$ is the Kronecker delta.
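Putting these pieces together, here is a rough NumPy sketch of a single collapsed Gibbs update for one token's topic assignment. The count arrays `a`, `b`, `b_sum` and the other names are mine (they are not the internals of the slda package used below); `alpha` and `beta` are assumed to be vectors of length $K$ and $V$ respectively.
import numpy as np

def resample_token(d, n, w_dn, z, a, b, b_sum, N_d, alpha, beta, eta, y, sigma2):
    """a: (D, K) doc-topic counts, b: (K, V) topic-term counts,
    b_sum: (K,) tokens per topic, z: (D, N) current topic assignments."""
    k_old = z[d, n]
    # remove token n of document d from the counts (the "set-minus" step)
    a[d, k_old] -= 1
    b[k_old, w_dn] -= 1
    b_sum[k_old] -= 1
    z_bar_minus = a[d] / N_d                   # empirical topic dist. of d without this token
    resid = y[d] - np.dot(eta, z_bar_minus)    # y^(d) - eta . (z_bar^(d) \ z_n^(d))
    # unnormalized full conditional over topics k = 1, ..., K
    p = ((b[:, w_dn] + beta[w_dn]) / (b_sum + beta.sum())
         * (a[d] + alpha)
         * np.exp(eta / (2 * sigma2 * N_d) * (2 * resid - eta / N_d)))
    p /= p.sum()
    k_new = np.random.choice(len(p), p=p)
    # add the token back under its newly sampled topic
    a[d, k_new] += 1
    b[k_new, w_dn] += 1
    b_sum[k_new] += 1
    z[d, n] = k_new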
We also need to calculate the full conditional for $\eta$. In order to do this, let $Z = (\etd{1} \cdots \etd{D})$ be the matrix whose columns are the empirical topic distributions $\etd{d}$, let $I$ be the identity matrix and $\one$ be the vector of ones, and note that
$$\prod_{k=1}^{K} \cN{\eta_{k}}{\mu, \nu^2} = \cN{\eta}{\mu \one, \nu^2 I}$$$$\prod_{d=1}^{D} \cN{y^{(d)}}{\eta \cdot \etd{d}, \sigma^2} = \cN{y}{Z^T \eta, \sigma^2 I}.$$Therefore
$$\begin{align} \cp{\eta}{z, w, y, \alpha, \beta, \mu, \nu^2, \sigma^2} &\propto \cp{z, w, \eta, y}{\alpha, \beta, \mu, \nu^2, \sigma^2} \\ &\propto \cN{\eta}{\mu \one, \nu^2 I} \cN{y}{Z^T \eta, \sigma^2 I} \\ &\propto \cN{\eta}{\Sigma \left( \nu^{-2} \mu \one + \sigma^{-2} Zy \right), \Sigma} \end{align}$$where $\Sigma^{-1} = \nu^{-2} I + \sigma^{-2} ZZ^T$ (see Section 9.3 of Kevin Murphy's notes for a derivation of Bayes rule for linear Gaussian systems). It is interesting to consider the mean and variance of $\eta$ in the two variance regimes $\sigma \gg \nu$ and $\sigma \ll \nu$. If $\sigma \gg \nu$, then
$$\Sigma^{-1} = \nu^{-2} \left( I + \left( \frac{\nu}{\sigma} \right)^2 ZZ^T \right) \approx \nu^{-2} I$$which implies that the covariance structure of $\eta$ is $\Sigma \approx \nu^2 I$ and the mean of $\eta$ is
$$\Sigma \left( \nu^{-2} \mu \one + \sigma^{-2} Zy \right) \approx \mu \one + \left( \frac{\nu}{\sigma} \right)^2 Zy \approx \mu \one,$$thus $\eta$ is approximately distributed according to its prior distribution. On the other hand, if $\sigma \ll \nu$, then
$$\Sigma^{-1} = \sigma^{-2} \left( \left( \frac{\sigma}{\nu} \right)^2 I + ZZ^T \right) \approx \sigma^{-2} ZZ^T.$$Notice that $ZZ^T$ is almost surely positive definite, and hence almost surely invertible. Therefore $\Sigma \approx \sigma^2 (ZZ^T)^{-1}$ and
$$\Sigma \left( \nu^{-2} \mu \one + \sigma^{-2} Zy \right) \approx \left( \frac{\sigma}{\nu} \right)^2 \mu (ZZ^T)^{-1} \one + (ZZ^T)^{-1} Zy \approx (ZZ^T)^{-1} Zy,$$thus $\eta$ is approximately distributed as the least-squares solution of $y = Z^T \eta$.
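In code, drawing $\eta$ from this full conditional is a single multivariate normal sample. A minimal sketch, assuming `Z` is the $K \times D$ matrix whose columns are the empirical topic distributions and `y` is the length-$D$ response vector (again, the names are mine):
import numpy as np

def sample_eta(Z, y, mu, nu2, sigma2):
    K = Z.shape[0]
    precision = np.eye(K) / nu2 + np.dot(Z, Z.T) / sigma2                # Sigma^{-1}
    Sigma = np.linalg.inv(precision)
    mean = np.dot(Sigma, mu * np.ones(K) / nu2 + np.dot(Z, y) / sigma2)  # posterior mean
    return np.random.multivariate_normal(mean, Sigma)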
In [2]:
%matplotlib inline
from modules.helpers import plot_images
from functools import partial
from sklearn.metrics import (mean_squared_error)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
imshow = partial(plt.imshow, cmap='gray', interpolation='nearest', aspect='auto')
rmse = lambda y_true, y_pred: np.sqrt(mean_squared_error(y_true, y_pred))
sns.set(style='white')
In [3]:
V = 25    # vocabulary size
K = 10    # number of topics
N = 100   # tokens per document
D = 1000  # number of documents
In [4]:
# ten "true" topics: five horizontal and five vertical bars on a 5x5 grid
topics = []
topic_base = np.concatenate((np.ones((1, 5)) * 0.2, np.zeros((4, 5))), axis=0).ravel()
for i in range(5):
    topics.append(np.roll(topic_base, i * 5))
topic_base = np.concatenate((np.ones((5, 1)) * 0.2, np.zeros((5, 4))), axis=1).ravel()
for i in range(5):
    topics.append(np.roll(topic_base, i))
topics = np.array(topics)
plt.figure(figsize=(10, 5))
plot_images(plt, topics, (5, 5), layout=(2, 5), figsize=(10, 5))
In [5]:
# generate documents: sample topic distributions, token topics, and terms
alpha = np.ones(K)
np.random.seed(42)
thetas = np.random.dirichlet(alpha, size=D)
topic_assignments = np.array([np.random.choice(range(K), size=N, p=theta)
                              for theta in thetas])
word_assignments = np.array([[np.random.choice(range(V), size=1, p=topics[topic_assignments[d, n]])[0]
                              for n in range(N)] for d in range(D)])
doc_term_matrix = np.array([np.histogram(word_assignments[d], bins=V, range=(0, V - 1))[0] for d in range(D)])
imshow(doc_term_matrix)
Out[5]:
In [6]:
# choose parameter values
nu2 = 10
sigma2 = 1
np.random.seed(42)
eta = np.random.normal(scale=nu2, size=K)
y = [np.dot(eta, thetas[i]) for i in range(D)] + np.random.normal(scale=sigma2, size=D)
# plot histogram of responses
print(eta)
_ = plt.hist(y, bins=20)
In [7]:
from slda.topic_models import SLDA
In [9]:
_K = 10
_alpha = alpha
_beta = np.repeat(0.01, V)
_mu = 0
_nu2 = nu2
_sigma2 = sigma2
n_iter = 500
slda = SLDA(_K, _alpha, _beta, _mu, _nu2, _sigma2, n_iter, seed=42)
In [10]:
%%time
slda.fit(doc_term_matrix, y)
In [11]:
plot_images(plt, slda.phi, (5, 5), (2, 5), figsize=(10, 5))
print(slda.phi)
print(np.sum(slda.phi, axis=0))
print(np.sum(slda.phi, axis=1))
In [12]:
topic_order = [1, 2, 4, 3, 9, 0, 7, 5, 6, 8]  # hand-picked permutation matching inferred topics to the true topics
plot_images(plt, slda.phi[topic_order], (5, 5), (2, 5), figsize=(10, 5))
In [13]:
imshow(slda.theta)
Out[13]:
In [14]:
plt.plot(slda.loglikelihoods)
Out[14]:
In [15]:
plt.plot(np.diff(slda.loglikelihoods)[-100:])
Out[15]:
In [13]:
burn_in = max(n_iter - 100, int(n_iter / 2))
slda.loglikelihoods[burn_in:].mean()
Out[13]:
In [14]:
eta_pred = slda.eta[burn_in:].mean(axis=0)
print(eta)
print(eta_pred[topic_order])
np.linalg.norm(eta - eta_pred[topic_order])
Out[14]:
Create 1,000 test documents using the same generative process as our training documents, and compute their true (noise-free) responses.
In [15]:
np.random.seed(42^2)
thetas_test = np.random.dirichlet(np.ones(K), size=D)
topic_assignments_test = np.array([np.random.choice(range(K), size=N, p=theta)
                                   for theta in thetas_test])
word_assignments_test = np.array([[np.random.choice(range(V), size=1, p=topics[topic_assignments_test[d, n]])[0]
                                   for n in range(N)] for d in range(D)])
doc_term_matrix_test = np.array([np.histogram(word_assignments_test[d], bins=V, range=(0, V - 1))[0] for d in range(D)])
y_test = [np.dot(eta, thetas_test[i]) for i in range(D)]
imshow(doc_term_matrix_test)
Out[15]:
Estimate their topic distributions using the trained model, then calculate the predicted responses using the mean of our samples of $\eta$ after burn-in as an estimate for $\eta$.
In [16]:
thetas_test_slda = slda.transform(doc_term_matrix_test)
y_slda = [np.dot(eta_pred, thetas_test_slda[i]) for i in range(D)]
Measure the quality of our predictions using root-mean-square error (RMSE).
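Explicitly, for $D$ test documents,
$$\mathrm{rmse}(y, \hat{y}) = \sqrt{\frac{1}{D} \sum_{d=1}^{D} \left( y^{(d)} - \hat{y}^{(d)} \right)^2}.$$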
In [17]:
rmse(y_test, y_slda)
Out[17]:
In [18]:
from slda.topic_models import LDA
In [19]:
lda = LDA(_K, _alpha, _beta, n_iter, seed=42)
In [20]:
%%time
lda.fit(doc_term_matrix)
In [21]:
plot_images(plt, lda.phi, (5, 5), (2, 5), figsize=(10, 5))
In [22]:
topic_order_lda = [1, 2, 0, 3, 9, 4, 7, 5, 6, 8]  # hand-picked permutation matching LDA topics to the true topics
plot_images(plt, lda.phi[topic_order_lda], (5, 5), (2, 5), figsize=(10, 5))
In [23]:
imshow(lda.theta)
Out[23]:
In [24]:
plt.plot(lda.loglikelihoods)
Out[24]:
In [25]:
thetas_test_lda = lda.transform(doc_term_matrix_test)
In [26]:
from sklearn.linear_model import LinearRegression
In [27]:
lr = LinearRegression(fit_intercept=False)
lr.fit(lda.theta, y)
y_lr = lr.predict(thetas_test_lda)
rmse(y_test, y_lr)
Out[27]:
In [28]:
print(eta)
print(lr.coef_[topic_order_lda])
np.linalg.norm(eta - lr.coef_[topic_order_lda])
Out[28]:
In [29]:
from sklearn.linear_model import Ridge
In [30]:
lrl2 = Ridge(alpha=1., fit_intercept=False)
lrl2.fit(lda.theta, y)
y_lrl2 = lrl2.predict(thetas_test_lda)
rmse(y_test, y_lrl2)
Out[30]:
In [31]:
print(eta)
print(lrl2.coef_[topic_order_lda])
np.linalg.norm(eta - lrl2.coef_[topic_order_lda])
Out[31]:
In [32]:
from sklearn.ensemble import GradientBoostingRegressor
In [33]:
gbr = GradientBoostingRegressor()
gbr.fit(lda.theta, y)
y_gbr = gbr.predict(thetas_test_lda)
rmse(y_test, y_gbr)
Out[33]:
SLDA is slightly better than unregularized linear regression, and better than ridge regression and gradient-boosted regression trees. The near-identical performance to unregularized linear regression is likely because this test was set up as an exact problem: every parameter used in training, except $\beta$ (because we hand-picked the topics) and $\eta$ (because that is the parameter we want to learn), was the one used to generate the documents. The remaining difference is that in SLDA the likelihood of $y$ introduces a regularization on $\eta$ through the full conditional for $z$, while unregularized linear regression enforces no such penalty.
We now redo the previous test, but this time with fewer topics than were used to generate the documents, in order to see how SLDA compares with LDA-plus-regression when the model is misspecified. Because we will no longer be solving an exact problem (the number of topics, and hence $\alpha$, will differ from those of the document-generation process), we expect SLDA to do better than LDA followed by a regression, including unregularized linear regression.
In [34]:
_K = 5
_alpha = np.repeat(1. / _K, _K)
_beta = np.repeat(0.01, V)
_mu = 0
_nu2 = nu2
_sigma2 = sigma2
n_iter = 500
slda1 = SLDA(_K, _alpha, _beta, _mu, _nu2, _sigma2, n_iter, seed=42)
In [35]:
%%time
slda1.fit(doc_term_matrix, y)
In [36]:
plot_images(plt, slda1.phi, (5, 5), (1, 5), figsize=(10, 5))
In [37]:
imshow(slda1.theta)
Out[37]:
In [38]:
plt.plot(slda1.loglikelihoods)
Out[38]:
In [39]:
burn_in1 = max(n_iter - 100, int(n_iter / 2))
slda1.loglikelihoods[burn_in1:].mean()
Out[39]:
In [40]:
eta_pred1 = slda1.eta[burn_in1:].mean(axis=0)
eta_pred1
Out[40]:
In [41]:
thetas_test_slda1 = slda1.transform(doc_term_matrix_test)
y_slda1 = [np.dot(eta_pred1, thetas_test_slda1[i]) for i in range(D)]
In [42]:
rmse(y_test, y_slda1)
Out[42]:
In [43]:
lda1 = LDA(_K, _alpha, _beta, n_iter, seed=42)
In [44]:
%%time
lda1.fit(doc_term_matrix)
In [45]:
plot_images(plt, lda1.phi, (5, 5), (1, 5), figsize=(10, 5))
We plot the SLDA topics again and note that they are indeed different from the LDA topics!
In [46]:
plot_images(plt, slda1.phi, (5, 5), (1, 5), figsize=(10, 5))
In [47]:
imshow(lda1.theta)
Out[47]:
In [48]:
plt.plot(lda1.loglikelihoods)
Out[48]:
In [49]:
thetas_test_lda1 = lda1.transform(doc_term_matrix_test)
In [50]:
lr1 = LinearRegression(fit_intercept=False)
lr1.fit(lda1.theta, y)
y_lr1 = lr1.predict(thetas_test_lda1)
rmse(y_test, y_lr1)
Out[50]:
In [51]:
lrl21 = Ridge(alpha=1., fit_intercept=False)
lrl21.fit(lda1.theta, y)
y_lrl21 = lrl21.predict(thetas_test_lda1)
rmse(y_test, y_lrl21)
Out[51]:
In [52]:
gbr1 = GradientBoostingRegressor()
gbr1.fit(lda1.theta, y)
y_gbr1 = gbr1.predict(thetas_test_lda1)
rmse(y_test, y_gbr1)
Out[52]:
In [53]:
lr1_0 = LinearRegression(fit_intercept=False)
lr1_0.fit(slda1.theta, y)
y_lr1_0 = lr1_0.predict(thetas_test_slda1)
rmse(y_test, y_lr1_0)
Out[53]:
In [54]:
lrl21_0 = Ridge(alpha=1., fit_intercept=False)
lrl21_0.fit(slda1.theta, y)
y_lrl21_0 = lrl21_0.predict(thetas_test_slda1)
rmse(y_test, y_lrl21_0)
Out[54]:
In [55]:
gbr1_0 = GradientBoostingRegressor()
gbr1_0.fit(slda1.theta, y)
y_gbr1_0 = gbr1_0.predict(thetas_test_slda1)
rmse(y_test, y_gbr1_0)
Out[55]: