In this notebook, we will learn how to visualize topic clusters using a dendrogram. A dendrogram is a tree-structured graph used to visualize the result of a hierarchical clustering calculation. Hierarchical clustering puts individual data points into similarity groups without prior knowledge of the groups. We can use it to explore the topic model and see how the topics are connected to each other through the sequence of successive fusions (or divisions) that occurs during the clustering process.
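To make the "sequence of successive fusions" concrete, here is a minimal, deliberately naive sketch of bottom-up (agglomerative) clustering with single linkage, the linkage method used later in this notebook. It is plain Python and is not how SciPy implements `linkage`; `single_linkage` and `cluster_distance` are hypothetical helpers for illustration only.

```python
from itertools import combinations


def cluster_distance(a, b, points, dist):
    # single linkage: distance between the closest pair of members
    return min(dist(points[i], points[j]) for i in a for j in b)


def single_linkage(points, dist):
    """Hypothetical helper: naive bottom-up clustering with single linkage.

    Returns the fusions in order as (cluster_a, cluster_b, height), where
    each cluster is a frozenset of point indices.
    """
    clusters = [frozenset([i]) for i in range(len(points))]
    merges = []
    while len(clusters) > 1:
        # find the closest pair of current clusters
        a, b = min(combinations(clusters, 2),
                   key=lambda pair: cluster_distance(pair[0], pair[1], points, dist))
        height = cluster_distance(a, b, points, dist)
        # replace the two clusters with their union
        clusters = [c for c in clusters if c not in (a, b)] + [a | b]
        merges.append((a, b, height))
    return merges


merges = single_linkage([0.0, 1.0, 5.0, 6.5], lambda x, y: abs(x - y))
for a, b, height in merges:
    print(sorted(a), sorted(b), height)
# [0] [1] 1.0
# [2] [3] 1.5
# [0, 1] [2, 3] 4.0
```

Each printed row is one fusion: the two clusters joined and the height (distance) at which they join. These heights are what a dendrogram draws on its y-axis.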
In [ ]:
!pip install "plotly>=2.0.16" # plotly 2.0.16+ is needed for the 'hovertext' argument of create_dendrogram (quotes keep the shell from treating >= as a redirect)
In [1]:
from gensim.models.ldamodel import LdaModel
from gensim.corpora import Dictionary
from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation
import numpy as np
import pandas as pd
import re
import plotly.offline as py
import plotly.graph_objs as go
import plotly.figure_factory as ff
py.init_notebook_mode()
We'll use the fake news dataset from Kaggle for this notebook. The first step is to preprocess the data and train our topic model using LDA. You can also refer to this notebook for tips and suggestions on preprocessing text data and on training an LDA model that gives good results.
In [2]:
df_fake = pd.read_csv('fake.csv')
df_fake[['title', 'text', 'language']].head()
# keep English articles that have text; .copy() avoids pandas' SettingWithCopyWarning below
df_fake = df_fake.loc[(pd.notnull(df_fake.text)) & (df_fake.language == 'english')].copy()
# remove stopwords and punctuation
def preprocess(text):
    return strip_punctuation(remove_stopwords(text.lower()))
df_fake['text'] = df_fake['text'].apply(preprocess)
# Convert the data into the input format required by LDA: one list of tokens per document
texts = []
for line in df_fake.text:
    # text is already lowercased by preprocess(); in Python 3, \w is Unicode-aware by default,
    # and re.LOCALE cannot be used with str patterns, so no flags are needed
    words = re.findall(r'\w+', line)
    texts.append(words)
# Create a dictionary representation of the documents.
dictionary = Dictionary(texts)
# Filter out words that occur in fewer than 2 documents, or in more than 40% of the documents.
dictionary.filter_extremes(no_below=2, no_above=0.4)
# Bag-of-words representation of the documents.
corpus_fake = [dictionary.doc2bow(text) for text in texts]
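The `no_below`/`no_above` thresholds are easy to misread, so here is a small plain-Python sketch of the same idea on a toy corpus. `tokenize`, `filter_by_doc_freq` and `to_bow` are hypothetical helpers, not gensim code; they follow the thresholds as described for `filter_extremes` but ignore its `keep_n` cap and other details, and the documents are invented.

```python
import re
from collections import Counter


def tokenize(text):
    # runs of word characters, as in the notebook's re.findall(r'\w+', ...)
    return re.findall(r'\w+', text.lower())


def filter_by_doc_freq(texts, no_below=2, no_above=0.4):
    """Hypothetical helper: keep tokens found in at least `no_below` documents
    and in at most a `no_above` fraction of all documents."""
    n_docs = len(texts)
    # document frequency: count each token once per document
    doc_freq = Counter(token for doc in texts for token in set(doc))
    return {token for token, df in doc_freq.items()
            if df >= no_below and df / n_docs <= no_above}


def to_bow(doc, vocab_ids):
    # bag of words: sorted (token id, count) pairs, out-of-vocabulary tokens dropped
    counts = Counter(vocab_ids[token] for token in doc if token in vocab_ids)
    return sorted(counts.items())


docs = ["Election fraud claims spread online",
        "Election results spark claims",
        "Vaccine claims spread",
        "Markets rally",
        "Vaccine trial results"]
texts = [tokenize(d) for d in docs]
vocab = filter_by_doc_freq(texts, no_below=2, no_above=0.4)
vocab_ids = {token: i for i, token in enumerate(sorted(vocab))}
print(sorted(vocab))                # ['election', 'results', 'spread', 'vaccine']
print(to_bow(texts[0], vocab_ids))  # [(0, 1), (2, 1)]
```

Here `claims` appears in 3 of 5 documents (60%), so it exceeds `no_above=0.4`, while words that occur in only one document fall below `no_below=2`.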
In [ ]:
lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=35, passes=30, chunksize=1500, iterations=200, alpha='auto')
lda_fake.save('lda_35')
In [3]:
lda_fake = LdaModel.load('lda_35')
In [4]:
from gensim.matutils import jensen_shannon
from scipy.cluster import hierarchy as sch
from scipy.spatial.distance import pdist
# get topic-word weights, one row per topic (lambda holds unnormalized variational parameters)
topic_dist = lda_fake.state.get_lambda()
# get topic terms
num_words = 300
topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]
# no. of terms to display in annotation
n_ann_terms = 10
# use Jensen-Shannon distance metric in dendrogram
def js_dist(X):
    return pdist(X, lambda u, v: jensen_shannon(u, v))
# define method for distance calculation in clusters
linkagefun = lambda x: sch.linkage(x, 'single')
# calculate text annotations
def text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun):
    # get dendrogram hierarchy data, using the linkagefun passed in
    d = js_dist(topic_dist)
    Z = linkagefun(d)
    P = sch.dendrogram(Z, orientation="bottom", no_plot=True)
    # scipy places leaves at x = 5, 15, 25, ...; map each topic no. (leaf) to its x-tick
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    x_topic = dict(zip(P['leaves'], x_ticks))
    # store {x-tick: (topic terms, topic terms)} for every leaf
    topic_vals = dict()
    for key, val in x_topic.items():
        topic_vals[val] = (topic_terms[key], topic_terms[key])
    text_annotations = []
    # loop through every trace (scatter plot) in dendrogram
    for trace in P['icoord']:
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]
        # annotation for two ends of current trace
        pos_tokens_t1 = list(fst_topic[0])[:n_ann_terms]
        neg_tokens_t1 = list(fst_topic[1])[:n_ann_terms]
        pos_tokens_t4 = list(scnd_topic[0])[:n_ann_terms]
        neg_tokens_t4 = list(scnd_topic[1])[:n_ann_terms]
        t1 = "<br>".join((": ".join(("+++", str(pos_tokens_t1))), ": ".join(("---", str(neg_tokens_t1)))))
        t2 = t3 = ()
        t4 = "<br>".join((": ".join(("+++", str(pos_tokens_t4))), ": ".join(("---", str(neg_tokens_t4)))))
        # show topic terms in leaves
        if trace[0] in x_ticks:
            t1 = str(list(topic_vals[trace[0]][0])[:n_ann_terms])
        if trace[2] in x_ticks:
            t4 = str(list(topic_vals[trace[2]][0])[:n_ann_terms])
        text_annotations.append([t1, t2, t3, t4])
        # calculate intersecting/diff for upper level
        intersecting = fst_topic[0] & scnd_topic[0]
        different = fst_topic[0].symmetric_difference(scnd_topic[0])
        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = (intersecting, different)
        # remove trace value after it is annotated
        topic_vals.pop(trace[0], None)
        topic_vals.pop(trace[2], None)
    return text_annotations
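`js_dist` above hands gensim's `jensen_shannon` to SciPy's `pdist`. To show what is being computed, here is a NumPy sketch of the Jensen-Shannon divergence, JS(P, Q) = ½ KL(P || M) + ½ KL(Q || M) with M = ½ (P + Q). `js_divergence` is a hypothetical helper, not gensim's implementation, and the topic weights are made up. It normalizes its inputs first, which matters here because the rows of `lda_fake.state.get_lambda()` are unnormalized weights rather than probabilities. SciPy's `jensenshannon` returns the square root of this divergence (the Jensen-Shannon distance); since single linkage depends only on the ordering of the distances, swapping one for the other changes the merge heights but not the shape of the tree.

```python
import numpy as np
from scipy.spatial.distance import jensenshannon, pdist, squareform


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two non-negative vectors.

    Both vectors are normalized to sum to 1 first, so unnormalized
    topic-word weights can be passed in directly.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p, q = p / p.sum(), q / q.sum()
    m = 0.5 * (p + q)  # mixture distribution

    def kl(a, b):
        # KL(a || b); terms where a == 0 contribute 0 by convention,
        # and b > 0 wherever a > 0 because b is the mixture m
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


# Hypothetical topic-word weights for 3 topics over a 4-word vocabulary.
topics = np.array([[10, 5, 1, 0],
                   [9, 6, 1, 0],
                   [0, 1, 5, 10]], dtype=float)

# Condensed pairwise matrix: the same kind of input that linkage() expects.
condensed = pdist(topics, js_divergence)
print(squareform(condensed).round(3))
```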
In [5]:
# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun)
# Plot dendrogram
dendro = ff.create_dendrogram(topic_dist, distfun=js_dist, labels=range(1, 36), linkagefun=linkagefun, hovertext=annotation)
dendro['layout'].update({'width': 1000, 'height': 600})
py.iplot(dendro)
The x-axis, i.e. the leaves of the hierarchy, represents the topics of our LDA model, and the y-axis measures how close individual topics or their clusters are. Under single linkage, the height at which two branches merge is the smallest distance between any member of one branch and any member of the other, so the lower two branches merge, the more similar they are. For example, topics 4 and 30 are more similar to each other than to topic 32. In addition, topics 18 and 24 are more similar to topic 35 than topics 4 and 30 are to topic 32, because they merge at a lower height than 4/30 merges with 32.
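`text_annotation` relies on two details of SciPy's `dendrogram(..., no_plot=True)` output: leaves are laid out at x = 5, 15, 25, ..., and each entry of `icoord`/`dcoord` describes one U-shaped link whose top sits at the merge height. A small sketch on made-up 1-D points shows both; the data is hypothetical and only stands in for the topic distributions.

```python
import numpy as np
from scipy.cluster import hierarchy as sch
from scipy.spatial.distance import pdist

# Hypothetical 1-D points standing in for topic distributions.
X = np.array([[0.0], [0.2], [3.0], [3.1], [9.0]])
Z = sch.linkage(pdist(X), 'single')
P = sch.dendrogram(Z, no_plot=True)

# Leaf k (left to right) is drawn at x = 10 * k + 5, so the notebook can map
# P['leaves'] onto np.arange(5, n * 10 + 5, 10).
leaf_x = dict(zip(P['leaves'], range(5, 10 * len(P['leaves']) + 5, 10)))

# Each link is a U shape: dcoord holds [left foot, top, top, right foot],
# so dcoord[1] is the merge height (a foot at 0.0 means that end is a leaf).
heights = sorted(link[1] for link in P['dcoord'])
print(P['leaves'], leaf_x)
print(heights)  # the same values as the third column of Z, sorted
```

Reading the plot works the same way: the lower the top of a link, the closer the two branches it joins.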
Hovering over a cluster node shows the intersecting (+++) and different (---) terms of its two child nodes. Nodes at the first level of the hierarchy compute these terms directly from the topic terms of their leaves, while upper nodes use the intersection (+++) of each child node as that child's topic terms.
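The propagation rule for the hover annotations can be sketched on its own in plain Python. `merge_annotation` and `leaf` are hypothetical names and the term sets are invented; the rule itself (intersection and symmetric difference, with upper nodes using each child's intersection) mirrors the `text_annotation` function above.

```python
def merge_annotation(left, right):
    """Hypothetical helper: annotate a cluster node from its two children.

    Each child is a (terms, diff) pair of sets. Only the first element is
    used: a leaf's own topic terms, or an inner node's intersection.
    """
    intersecting = left[0] & right[0]  # +++ terms shared by both children
    different = left[0] ^ right[0]     # --- terms found in exactly one child
    return intersecting, different


def leaf(terms):
    # a leaf stores its topic terms in both slots, as text_annotation does
    return (terms, terms)


# Hypothetical topic term sets.
t_a = {"election", "vote", "poll", "campaign"}
t_b = {"election", "vote", "fraud"}
t_c = {"vote", "senate", "bill"}

node_ab = merge_annotation(leaf(t_a), leaf(t_b))  # first-level node
root = merge_annotation(node_ab, leaf(t_c))       # upper node uses node_ab's +++ terms
print("+++", sorted(node_ab[0]), "---", sorted(node_ab[1]))
# +++ ['election', 'vote'] --- ['campaign', 'fraud', 'poll']
print("+++", sorted(root[0]), "---", sorted(root[1]))
# +++ ['vote'] --- ['bill', 'election', 'senate']
```

Note how `election` is shared at the first level but becomes a differing term at the root, because the root only sees the first-level node's `+++` terms.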
This kind of tree graph can help reveal high-level cluster themes in our data, since each cluster head's annotation shows the common and differing terms of the topics it combines.
In [6]:
# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun)
# Initialize figure by creating upper dendrogram
figure = ff.create_dendrogram(topic_dist, distfun=js_dist, labels=range(1, 36), linkagefun=linkagefun, hovertext=annotation)
for i in range(len(figure['data'])):
    figure['data'][i]['yaxis'] = 'y2'
In [7]:
# get distance matrix and its topic annotations
mdiff, annotation = lda_fake.diff(lda_fake, distance="jensen_shannon", normed=False)
# get reordered topic list
dendro_leaves = figure['layout']['xaxis']['ticktext']
dendro_leaves = [x - 1 for x in dendro_leaves]
# reorder distance matrix
heat_data = mdiff[dendro_leaves, :]
heat_data = heat_data[:, dendro_leaves]
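Reordering a square matrix into dendrogram leaf order means permuting rows and columns with the same index list. The sketch below uses a made-up 4-topic matrix and a hypothetical leaf order to show that the two-step indexing in the cell above matches NumPy's `np.ix_`, and that per-cell data kept in nested lists (such as hover text) must be permuted the same way for each label to stay on its cell.

```python
import numpy as np

# Hypothetical symmetric distance matrix for 4 topics.
mdiff = np.array([[0.0, 0.9, 0.2, 0.8],
                  [0.9, 0.0, 0.7, 0.1],
                  [0.2, 0.7, 0.0, 0.6],
                  [0.8, 0.1, 0.6, 0.0]])
# Hypothetical 0-based leaf order: topics 0 and 2 are neighbours, as are 1 and 3.
leaves = [0, 2, 1, 3]

# Two-step indexing: reorder rows, then columns.
two_step = mdiff[leaves, :][:, leaves]
# The same result in one step: np.ix_ builds row/column index arrays that broadcast.
one_step = mdiff[np.ix_(leaves, leaves)]

# Per-cell data stored as nested lists needs the same reordering.
labels = [["{}-{}".format(i, j) for j in range(4)] for i in range(4)]
labels_reordered = [[labels[i][j] for j in leaves] for i in leaves]
print(one_step)
print(labels_reordered[0])  # ['0-0', '0-2', '0-1', '0-3']
```

After reordering, the small distances (0.2 and 0.1) sit next to the diagonal, which is what makes the clusters visible as blocks in the heatmap.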
In [8]:
# heatmap annotation, reordered to the dendrogram leaf order so each label matches its cell in heat_data
annotation_html = [["+++ {}<br>--- {}".format(", ".join(annotation[i][j][0]), ", ".join(annotation[i][j][1]))
                    for j in dendro_leaves] for i in dendro_leaves]
# plot heatmap of distance matrix, positioned on the dendrogram's x-ticks
heatmap = go.Heatmap(
    x=figure['layout']['xaxis']['tickvals'],
    y=figure['layout']['xaxis']['tickvals'],
    z=heat_data,
    colorscale='YlGnBu',
    text=annotation_html,
    hoverinfo='x+y+z+text'
)
# Add Heatmap Data to Figure (rebuilding the figure works whether figure['data'] is a list or a tuple)
figure = go.Figure(data=list(figure['data']) + [heatmap], layout=figure['layout'])
dendro_leaves = [x + 1 for x in dendro_leaves]
# Edit Layout
figure['layout'].update({'width': 800, 'height': 800,
                         'showlegend': False, 'hovermode': 'closest',
                         })
# Edit xaxis
figure['layout']['xaxis'].update({'domain': [.25, 1],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": figure['layout']['xaxis']['tickvals'],
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis
figure['layout']['yaxis'].update({'domain': [0, 0.75],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": figure['layout']['xaxis']['tickvals'],
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis2
figure['layout'].update({'yaxis2': {'domain': [0.75, 1],
                                    'mirror': False,
                                    'showgrid': False,
                                    'showline': False,
                                    'zeroline': False,
                                    'showticklabels': False,
                                    'ticks': ""}})
py.iplot(figure)
The heatmap shows the exact distance between any two topics as the z-value of their cell, along with their intersecting (+++) and different (---) terms in the hover annotation. It also shows the distance between topics that are not directly connected in the dendrogram.
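The dendrogram only tells us the height at which two topics first end up in the same cluster (their cophenetic distance), whereas the heatmap shows their direct distance. A short SciPy sketch on hypothetical 1-D points shows the difference; with single linkage the tree height is never larger than the direct distance, because clusters can be chained together through intermediate points.

```python
import numpy as np
from scipy.cluster import hierarchy as sch
from scipy.spatial.distance import pdist, squareform

# Hypothetical 1-D "topics"; point 4 is far from everything else.
X = np.array([[0.0], [0.2], [3.0], [3.1], [9.0]])
d = pdist(X)                  # direct pairwise distances: what the heatmap cells show
Z = sch.linkage(d, 'single')  # single linkage, as in the notebook
coph = sch.cophenet(Z)        # height at which each pair first joins in the tree

direct, tree = squareform(d), squareform(coph)
# Points 0 and 4 are 9.0 apart, but the tree joins them across the 3.1 -> 9.0 gap.
print(direct[0, 4], tree[0, 4])
```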