In [1]:
# This notebook is used to decide on a tolerable level of corruptibility.
%matplotlib inline
import numpy as np
from scipy.stats import entropy as KL_divergence
import pandas as pd
import matplotlib.pyplot as plt
from slda.topic_models import RTM
from modules.helpers import plot_images
In [2]:
# Build the synthetic ground-truth topics.
# The vocabulary has rows**2 terms, viewed as a rows x rows grid. There are
# 2 * rows topics: the first `rows` topics each spread probability 1/rows over
# one grid row (consecutive terms), the last `rows` topics each spread it over
# one grid column (every rows-th term).
rows = 3
V = rows * rows   # vocabulary size
K = rows * 2      # number of topics
N = K * K         # tokens sampled per document
D = 10000         # number of documents
seed = 42

cell_prob = 1 / rows

# Template with the first grid row filled; rolling by a multiple of `rows`
# moves the filled row down the grid.
first_row = np.concatenate(
    (np.ones((1, rows)) * cell_prob, np.zeros((rows - 1, rows))), axis=0
).ravel()
row_topics = [np.roll(first_row, r * rows) for r in range(rows)]

# Template with the first grid column filled; rolling by one position moves
# the filled column one step to the right.
first_col = np.concatenate(
    (np.ones((rows, 1)) * cell_prob, np.zeros((rows, rows - 1))), axis=1
).ravel()
col_topics = [np.roll(first_col, c) for c in range(rows)]

# Final (K, V) array: row topics first, then column topics.
topics = np.array(row_topics + col_topics)
# Generate documents from topics
# We generate D documents from these K topics by sampling D topic
# distributions, one for each document, from a Dirichlet distribution with
# parameter α=(1,…,1).
alpha = np.ones(K)
np.random.seed(seed)
thetas = np.random.dirichlet(alpha, size=D)

# For every document, draw N topic indices from that document's own topic
# distribution (one row of `thetas`).
topic_assignments = np.array([np.random.choice(range(K), size=N, p=theta)
                              for theta in thetas])

# For every token, draw one vocabulary index from the term distribution of the
# topic assigned to it. The one-draw-per-token call pattern is kept unchanged
# so the sequence of random draws (and therefore the corpus) stays the same.
word_assignments = np.array([[np.random.choice(range(V), size=1,
                                               p=topics[topic_assignments[d, n]])[0]
                              for n in range(N)] for d in range(D)])

# Per-document term counts, one row per document and one column per term.
# np.bincount counts the integer term ids directly; minlength=V keeps a column
# for terms that never occur in a document.
doc_term_matrix = np.array([np.bincount(word_assignments[d], minlength=V)
                            for d in range(D)])
# Choose parameter values passed to the model below.
mu = 0.
sigma2 = 1.
nu = 1.
# Re-seed NumPy's global RNG before fitting; the model additionally receives
# its own seed=42 in the constructor call below.
np.random.seed(14)

# Estimate parameters
# Underscore-prefixed names are the values handed to the model; here they are
# copies of the generating values.
_K = K
_alpha = alpha[:_K]
_beta = np.repeat(0.01, V)  # one value of 0.01 per vocabulary term
_mu = mu
_sigma2 = sigma2
_nu = nu
n_iter = 1000
rtm = RTM(_K, _alpha, _beta, _mu, _sigma2, _nu, n_iter, seed=42)
# NOTE(review): only the document-term matrix is passed to fit(); confirm
# against slda's RTM.fit signature whether further inputs are expected.
rtm.fit(doc_term_matrix)
# Read the estimates from the model fitted above (`rtm`). The previous code
# referenced an undefined name `grtm`, which raised NameError.
results = rtm.phi
In [ ]:
# For each estimated topic, print its KL divergence to the closest true topic,
# i.e. the minimum of KL(topic || res) over all rows of `topics`.
# A true minimum is taken instead of starting the running minimum at 1, which
# silently reported 1 whenever every divergence was larger than that.
for res in results:
    minimized_KL = min(KL_divergence(topic, res) for topic in topics)
    print(minimized_KL)
In [ ]:
# Show the estimated topics via the project helper.
# NOTE(review): (rows, rows) and (2, rows) are presumably the per-image shape
# and the subplot grid layout; confirm against modules.helpers.plot_images.
plot_images(plt, results, (rows, rows), (2, rows))

# Log-likelihood values recorded on the fitted model, in a new figure.
loglik_trace = rtm.loglikelihoods
plt.figure()
plt.plot(loglik_trace)

# Successive differences of the trace, restricted to the last 100 entries,
# in another new figure.
plt.figure()
plt.plot(np.diff(loglik_trace)[-100:])
In [ ]: