An example showing the plot_silhouette method used by a scikit-learn clusterer

In this example, we'll perform a silhouette analysis on the clusters found by K-Means clustering. First we'll create an instance of KMeans, then fit it to the data. Afterwards we can pass the data X together with the cluster_labels, i.e. the output of the fit_predict method, into the skplt.metrics.plot_silhouette method.


In [1]:
from __future__ import absolute_import
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris as load_data

# Import scikit-plot
import scikitplot as skplt

# Render figures inline and enlarge the default figure size
%matplotlib inline
plt.rcParams['figure.figsize'] = (14, 14)

In [2]:
# Load the data
X, y = load_data(return_X_y=True)

# Create an instance of the clusterer then fit
kmeans = KMeans(n_clusters=4, random_state=1)
cluster_labels = kmeans.fit_predict(X)

In [3]:
# Plot!
skplt.metrics.plot_silhouette(X, cluster_labels, cmap='nipy_spectral')
plt.show()
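
The silhouette plot draws one bar per sample and marks the average score. As a rough numeric companion to that picture, the sketch below computes the same quantities directly with scikit-learn's silhouette_score and silhouette_samples. The helper name silhouette_by_k, the range of k values tried, and n_init=10 are illustrative assumptions and not part of the original example.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import silhouette_samples, silhouette_score

X, _ = load_iris(return_X_y=True)


def silhouette_by_k(X, k_values, random_state=1):
    """Hypothetical helper: mean silhouette score for each candidate k."""
    scores = {}
    for k in k_values:
        # n_init is set explicitly (an assumption) so results do not depend
        # on the scikit-learn version's default.
        model = KMeans(n_clusters=k, n_init=10, random_state=random_state)
        labels = model.fit_predict(X)
        scores[k] = silhouette_score(X, labels)
    return scores


# Compare a few cluster counts; a higher mean score suggests
# better-separated clusters.
scores = silhouette_by_k(X, range(2, 7))
for k, score in scores.items():
    print(f"k={k}: mean silhouette = {score:.3f}")

# Per-sample coefficients for k=4: these are the values the plot draws as bars.
labels = KMeans(n_clusters=4, n_init=10, random_state=1).fit_predict(X)
per_sample = silhouette_samples(X, labels)
print("lowest per-sample coefficient:", per_sample.min())
```

Samples with a coefficient near or below zero sit close to a neighbouring cluster, so a large share of them is a hint to try a different n_clusters.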