This is a short notebook describing the Minimum Spanning Tree (MST) Clustering algorithm implemented at http://github.com/jakevdp/mst_clustering/. The API of the clustering estimator is compatible with scikit-learn.
Note that MST clustering is effectively identical to single-linkage Agglomerative Clustering, but uses a top-down approach rather than a bottom-up approach. Similar methods are implemented in scikit-learn's Agglomerative Clustering estimator.
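To make this connection concrete, here is a small sketch (on made-up toy points) using scikit-learn's single-linkage mode: stopping single-linkage merges at a distance threshold yields the same partition as cutting MST edges longer than that threshold. Note that the `distance_threshold` argument assumes a reasonably recent scikit-learn (0.21 or later):

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering

# Toy data (hypothetical): two tight pairs separated by a large gap
X = np.array([[0.0, 0.0], [0.5, 0.0], [10.0, 0.0], [10.5, 0.0]])

# Single-linkage with a distance threshold: merges stop once the closest
# pair of clusters is farther apart than the threshold
agg = AgglomerativeClustering(n_clusters=None, distance_threshold=2.0,
                              linkage='single')
labels = agg.fit_predict(X)
print(labels)  # two groups: the left pair vs. the right pair
```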
We'll start with some initial imports and definitions; you can ignore this for now:
In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
# matplotlib 1.4 + numpy 1.10 produces warnings; we'll filter these
import warnings; warnings.filterwarnings('ignore', message='elementwise')
def plot_mst(model, cmap='rainbow'):
    """Utility code to visualize a minimum spanning tree"""
    X = model.X_fit_
    fig, ax = plt.subplots(1, 2, figsize=(16, 6), sharex=True, sharey=True)
    for axi, full_graph, colors in zip(ax, [True, False], ['lightblue', model.labels_]):
        segments = model.get_graph_segments(full_graph=full_graph)
        axi.plot(segments[0], segments[1], '-k', zorder=1, lw=1)
        axi.scatter(X[:, 0], X[:, 1], c=colors, cmap=cmap, zorder=2)
        axi.axis('tight')
    ax[0].set_title('Full Minimum Spanning Tree', size=16)
    ax[1].set_title('Trimmed Minimum Spanning Tree', size=16);
First, let's create some sample data to cluster: four Gaussian blobs generated with scikit-learn's make_blobs:
In [2]:
from sklearn.datasets import make_blobs
X, y = make_blobs(200, centers=4, random_state=42)
plt.scatter(X[:, 0], X[:, 1], c='lightblue');
By eye we can see that the data contains four well-separated clusters, and a clustering algorithm should be able to recover them.
The MSTClustering estimator can accomplish this quite well:
In [3]:
from mst_clustering import MSTClustering
model = MSTClustering(cutoff_scale=2, approximate=False)
labels = model.fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='rainbow');
Notice that the algorithm finds five clusters: the one outlier near the yellow cluster is put in its own group.
This clustering result is computed by first constructing a graph from the input data, with nodes given by points in the dataset and edges given by distances between pairs of points. Then we find a Minimum Spanning Tree (MST) over this graph: an MST is a subgraph that connects all nodes such that the sum of the graph edges is minimized.
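As a sketch of this first step (on made-up toy points; the package computes the tree internally, but SciPy exposes the same primitive):

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import pdist, squareform

# Small toy dataset (hypothetical, for illustration only)
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0]])

# Dense pairwise-distance matrix -> sparse minimum spanning tree
dist = squareform(pdist(X))
mst = minimum_spanning_tree(dist)

# An MST over n nodes always has exactly n - 1 edges
print(mst.nnz)    # 3
print(mst.sum())  # total edge weight of the tree
```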
We can see a representation of the minimum spanning tree over this data in the left panel of the following plot:
In [4]:
plot_mst(model)
The right panel shows the clusters derived from this minimum spanning tree, obtained by removing all graph edges longer than the specified cutoff_scale parameter (here, cutoff_scale = 2).
Points which remain connected after the truncation are given the same label.
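This cut-and-relabel step can be sketched with SciPy's graph routines; the toy points and cutoff value here are illustrative, not the package's internals:

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree, connected_components
from scipy.spatial.distance import pdist, squareform

# Toy data (hypothetical): two tight pairs separated by a large gap
X = np.array([[0.0, 0.0], [0.5, 0.0], [10.0, 0.0], [10.5, 0.0]])
cutoff_scale = 2.0  # mirrors the role of the cutoff_scale parameter

mst = minimum_spanning_tree(squareform(pdist(X))).tocsr()

# Remove MST edges longer than the cutoff...
mst.data[mst.data > cutoff_scale] = 0
mst.eliminate_zeros()

# ...and give connected points the same label
n_clusters, labels = connected_components(mst, directed=False)
print(n_clusters)  # 2
print(labels)      # [0 0 1 1]
```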
MST clustering has several advantages over other clustering algorithms; many of these are shared with, e.g., DBSCAN, which uses a similar but conceptually distinct model. One such advantage is robustness to a noisy background, which we can demonstrate by adding uniform background noise to the data:
In [5]:
rng = np.random.RandomState(int(100 * y[-1]))
noise = -14 + 28 * rng.rand(200, 2)
X_noisy = np.vstack([X, noise])
y_noisy = np.concatenate([y, np.full(200, -1, dtype=int)])
plt.scatter(X_noisy[:, 0], X_noisy[:, 1], c='lightblue')
plt.xlim(-15, 15)
plt.ylim(-15, 15);
We can see that there are four distinct overdensities in this data. Let's take a look at what the minimum spanning tree does in this case:
In [6]:
noisy_model = MSTClustering(cutoff_scale=1)
noisy_model.fit(X_noisy)
plot_mst(noisy_model)
Cutting the long edges does separate the four large clusters, but it also leads to a large number of isolated clusters with just a handful of points.
We can easily filter these out by specifying a minimum cluster size; all points which don't belong to a large enough cluster are grouped together under the label -1:
In [7]:
noisy_model = MSTClustering(cutoff_scale=1, min_cluster_size=10)
noisy_model.fit(X_noisy)
plot_mst(noisy_model, cmap='Spectral_r')
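Under the hood, this filtering amounts to relabeling every component smaller than min_cluster_size as noise. A minimal sketch with hypothetical component labels:

```python
import numpy as np

# Hypothetical component labels from a truncated MST: one big cluster (0),
# one small cluster (1), and a lone point (2)
labels = np.array([0, 0, 0, 0, 0, 1, 1, 2])
min_cluster_size = 3

# Assign -1 (noise) to every component smaller than min_cluster_size
unique, counts = np.unique(labels, return_counts=True)
small = unique[counts < min_cluster_size]
labels = np.where(np.isin(labels, small), -1, labels)
print(labels)  # [ 0  0  0  0  0 -1 -1 -1]
```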
Thus, by appropriately tuning the cutoff scale and the minimum number of points per cluster, you can recover clusters even within a noisy background.
There are methods in the literature that discuss how to tune these parameters automatically; I hope to implement some of them in this package in the future.
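One simple diagnostic along these lines is to scan the cutoff and watch how the number of connected components changes; plateaus in that curve suggest stable clusterings. A sketch using SciPy, with made-up blob data and cutoff values chosen purely for illustration:

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree, connected_components
from scipy.spatial.distance import pdist, squareform

rng = np.random.RandomState(0)
# Two hypothetical Gaussian blobs separated along the x-axis
X = np.vstack([rng.randn(20, 2), rng.randn(20, 2) + [10, 0]])

mst = minimum_spanning_tree(squareform(pdist(X))).tocsr()

# Count clusters at a few cutoff scales: small cutoffs fragment the data,
# large cutoffs merge everything into one component
counts = []
for cutoff in [0.5, 2.0, 12.0]:
    trimmed = mst.copy()
    trimmed.data[trimmed.data > cutoff] = 0
    trimmed.eliminate_zeros()
    n, _ = connected_components(trimmed, directed=False)
    counts.append(n)
    print(cutoff, n)
```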