In this example, we'll perform a silhouette analysis on the clusters found by K-Means. First we'll create an instance of `KMeans` and fit it to the data. Afterwards we can pass the data together with the `cluster_labels`, i.e. the output of the `fit_predict` method, to the `skplt.metrics.plot_silhouette` function.
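For reference, the silhouette coefficient of a sample i is s(i) = (b(i) - a(i)) / max(a(i), b(i)), where a(i) is the mean distance from i to the other members of its own cluster and b(i) is the smallest mean distance from i to the members of any other cluster. Values close to 1 indicate a well-separated sample, values near 0 a sample on a cluster boundary, and negative values a sample that may be in the wrong cluster. The minimal sketch below computes these values directly from the definition and checks them against scikit-learn's `silhouette_samples`. It assumes Euclidean distance, and the helper name `silhouette_values` is hypothetical, not part of scikit-plot or scikit-learn.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import pairwise_distances, silhouette_samples


def silhouette_values(X, labels):
    """Hypothetical helper: per-sample silhouette coefficients from the definition.

    a(i): mean distance from sample i to the other members of its own cluster.
    b(i): smallest mean distance from sample i to the members of any other cluster.
    s(i) = (b(i) - a(i)) / max(a(i), b(i)); s(i) = 0 for singleton clusters.
    """
    D = pairwise_distances(X)  # Euclidean distances by default
    labels = np.asarray(labels)
    unique = np.unique(labels)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton cluster: leave s(i) = 0
        # D[i, i] is 0, so summing over the whole cluster and dividing by
        # n_same - 1 gives the mean distance to the *other* members
        a = D[i, same].sum() / (n_same - 1)
        b = min(D[i, labels == k].mean() for k in unique if k != labels[i])
        s[i] = (b - a) / max(a, b)
    return s


X, _ = load_iris(return_X_y=True)
labels = KMeans(n_clusters=4, random_state=1).fit_predict(X)

manual = silhouette_values(X, labels)
reference = silhouette_samples(X, labels)  # scikit-learn's implementation
print(np.allclose(manual, reference))
```

The explicit loop is O(n²) in memory and time because of the full distance matrix, which is fine for a 150-sample dataset like iris but not for large data; in practice the library functions are the better choice.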
In [1]:
from __future__ import absolute_import
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris as load_data
# Import scikit-plot
import scikitplot as skplt
# Render figures inline (%pylab is deprecated and clutters the namespace)
%matplotlib inline
plt.rcParams['figure.figsize'] = (14, 14)
In [2]:
# Load the data
X, y = load_data(return_X_y=True)
# Create an instance of the clusterer then fit
kmeans = KMeans(n_clusters=4, random_state=1)
cluster_labels = kmeans.fit_predict(X)
In [3]:
# Plot the silhouette coefficients of every sample, grouped by cluster
skplt.metrics.plot_silhouette(X, cluster_labels, cmap='nipy_spectral')
plt.show()
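A silhouette plot is built from per-sample coefficients, so it can help to look at the same information as numbers. The sketch below assumes that the values behind the plot match scikit-learn's `silhouette_samples` with Euclidean distance, and prints each cluster's size, mean and minimum coefficient, and count of negative coefficients, plus the overall `silhouette_score` (the mean over all samples).

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import silhouette_samples, silhouette_score

X, _ = load_iris(return_X_y=True)
cluster_labels = KMeans(n_clusters=4, random_state=1).fit_predict(X)

# One coefficient per sample, in [-1, 1]
sample_values = silhouette_samples(X, cluster_labels)

# Mean coefficient over all samples: a single summary number for the clustering
overall = silhouette_score(X, cluster_labels)

for k in np.unique(cluster_labels):
    members = sample_values[cluster_labels == k]
    print(f"cluster {k}: size={members.size}, mean={members.mean():.3f}, "
          f"min={members.min():.3f}, negative={np.sum(members < 0)}")
print(f"overall silhouette score: {overall:.3f}")
```

A cluster whose mean sits well below the overall score, or that contains many negative coefficients, is usually the first place to look when judging whether `n_clusters=4` suits this data.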