This notebook is a modified version of the one created by Jake Vanderplas for PyCon 2015.

Source and license info of the original notebook are on GitHub.


In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

# use seaborn plotting defaults
import seaborn as sns; sns.set()

1. Introducing K-Means

K-Means is an algorithm for unsupervised clustering: that is, it finds clusters in the data based on the data attributes alone, not the labels.

K-Means is a relatively easy-to-understand algorithm. It searches for cluster centers that are the mean of the points assigned to them, such that every point ends up closest to the center of the cluster it belongs to.
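Conceptually, the algorithm alternates between two steps: assign every point to its nearest center, then move each center to the mean of the points assigned to it. A minimal NumPy sketch of this idea (not part of the original notebook; it uses a random initialization and does not handle empty clusters) could look like this:

import numpy as np

def simple_kmeans(X, n_clusters, n_iter=10):
    # initialize the centers with randomly chosen data points
    centers = X[np.random.permutation(len(X))[:n_clusters]]
    for _ in range(n_iter):
        # assignment step: label every point with its nearest center
        dists = ((X[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        # update step: move each center to the mean of its assigned points
        centers = np.array([X[labels == k].mean(axis=0)
                            for k in range(n_clusters)])
    return centers, labels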

Let's look at how KMeans operates on a synthetic example. To emphasize that this is unsupervised, we will not use the true cluster labels to color the points:


In [2]:
from sklearn.datasets.samples_generator import make_blobs
X, y = make_blobs(n_samples=300, centers=4,
                  random_state=0, cluster_std=0.60)
plt.scatter(X[:, 0], X[:, 1], s=50);


By eye, it is relatively easy to pick out the four clusters. If you were to perform an exhaustive search for the different segmentations of the data, however, the search space would be exponential in the number of points. Fortunately, the $K$-Means algorithm implemented in Scikit-learn provides a much more convenient solution.

Exercise:

The following fragment of code runs the K-means method on the toy example you just created. Modify it to try other settings for the parameters implemented by the method (one possible set of modifications is sketched after the cell below). In particular:

  • Reduce the number of runs to check the consequences of a bad initialization
  • Test different kinds of initializations (k-means++ vs random)
  • Provide a user-defined initialization that you expect to lead to clearly suboptimal performance
  • Test other choices for the number of clusters
  • Include in the plot the location of the cluster centers

In [3]:
from sklearn.cluster import KMeans
est = KMeans(4)  # 4 clusters
est.fit(X)
y_kmeans = est.predict(X)
plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='rainbow');
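One possible way to explore these settings is sketched below (it reuses X from the cell above; the hand-picked initial centers and the parameter values are only illustrative guesses, not a recommended configuration):

# illustrative exercise sketch: a deliberately poor, hand-picked initialization
init_centers = np.array([[0., 0.], [0.5, 0.5], [1., 1.], [1.5, 1.5]])

est = KMeans(n_clusters=4,        # try other numbers of clusters as well
             init=init_centers,   # or init='k-means++' / init='random'
             n_init=1)            # a single run exposes bad initializations
est.fit(X)
y_kmeans = est.predict(X)

plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='rainbow')
# mark the location of the learned cluster centers
plt.scatter(est.cluster_centers_[:, 0], est.cluster_centers_[:, 1],
            c='black', s=200, marker='x');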


The K-Means Algorithm: Interactive visualization

The following fragment of code allows you to study the evolution of the cluster centroids over a single run of the algorithm, and also to modify the number of centroids.


In [4]:
from fig_code import plot_kmeans_interactive
plot_kmeans_interactive();


2. Application of KMeans to Digits

For a closer-to-real-world example, let us take a look at a digit recognition dataset. Here we'll use KMeans to automatically cluster the data in 64 dimensions, and then look at the cluster centers to see what the algorithm has found.


In [5]:
from sklearn.datasets import load_digits
digits = load_digits()

print 'Input data and label number are provided in the following two variables:'
print "digits['images']: " + str(digits['images'].shape)
print "digits['target']: " + str(digits['target'].shape)


Input data and label number are provided in the following two variables:
digits['images']: (1797, 8, 8)
digits['target']: (1797,)

Next, we cluster the data into 10 groups and plot the representatives (the centroid of each group). As with the toy example, you could modify the initialization settings to study the impact of initialization on the performance of the method.


In [6]:
est = KMeans(n_clusters=10)
clusters = est.fit_predict(digits.data)
est.cluster_centers_.shape


Out[6]:
(10, 64)

In [7]:
fig = plt.figure(figsize=(8, 3))
for i in range(10):
    ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
    ax.imshow(est.cluster_centers_[i].reshape((8, 8)), cmap=plt.cm.binary)


We see that even without the labels, KMeans is able to find clusters whose means are recognizable digits (with apologies to the number 8)!

For good measure, let us project the data onto its two principal components (using PCA) and compare the true labels with the K-Means cluster labels:


In [8]:
from sklearn.decomposition import PCA

X = PCA(2).fit_transform(digits.data)

kwargs = dict(cmap = plt.cm.get_cmap('rainbow', 10),
              edgecolor='none', alpha=0.6)
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].scatter(X[:, 0], X[:, 1], c=est.labels_, **kwargs)
ax[0].set_title('learned cluster labels')

ax[1].scatter(X[:, 0], X[:, 1], c=digits.target, **kwargs)
ax[1].set_title('true labels');


Just for kicks, let us see how accurate our K-Means clustering is with no label information. To do so, we can inspect the confusion matrix:


In [9]:
from sklearn.metrics import confusion_matrix

conf = confusion_matrix(digits.target, est.labels_)
print(conf)

plt.imshow(conf,
           cmap='Blues', interpolation='nearest')
plt.colorbar()
plt.grid(False)
plt.ylabel('true')
plt.xlabel('Group index');

# Compute the fraction of patterns that would be correctly classified if each true class were matched to its most frequent cluster
print 'Percentage of patterns that would be correctly classified: ' \
            + str(np.sum(np.max(conf,axis=1)) * 100. / np.sum(conf)) + '%'


[[  0 177   0   0   0   1   0   0   0   0]
 [ 99   0   1   1  24   0   0  55   2   0]
 [  8   1   0  13 148   0   2   2   0   3]
 [  7   0   2 154   0   0  13   0   0   7]
 [  3   0   0   0   0 163   0   7   0   8]
 [  0   0 136   1   0   2  42   0   1   0]
 [  2   1   0   0   0   0   0   1 177   0]
 [  2   0   0   0   0   0   0   2   0 175]
 [101   0   4   2   3   0  52   5   2   5]
 [  2   0   6   6   0   0 139  20   0   7]]
Percentage of patterns that would be correctly classified: 81.7473567056%

This is above 80% classification accuracy for an entirely unsupervised estimator which knew nothing about the labels.
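As a side note (this refinement is not part of the original notebook), the estimate above allows two digits to claim the same cluster; a stricter figure can be obtained by forcing a one-to-one matching between clusters and digits, for instance with the Hungarian algorithm provided by recent versions of SciPy:

# optional: one-to-one matching of clusters to digits via the Hungarian algorithm
from scipy.optimize import linear_sum_assignment

row_ind, col_ind = linear_sum_assignment(-conf)  # negate to maximize the matched counts
print('Accuracy with a one-to-one cluster-to-digit assignment: '
      + str(conf[row_ind, col_ind].sum() * 100. / conf.sum()) + '%')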

3. Example: KMeans for Color Compression

One interesting application of clustering is in color image compression. For example, imagine you have an image with millions of colors. In most images, a large number of the colors will be unused, and conversely a large number of pixels will have similar or identical colors.

Scikit-learn has a number of images that you can play with, accessed through the datasets module. For example:


In [10]:
from sklearn.datasets import load_sample_image
china = load_sample_image("china.jpg")
plt.imshow(china)
plt.grid(False);


The image itself is stored in a 3-dimensional array, of size (height, width, RGB). For each pixel three values are necessary, each in the range 0 to 255. This means that each pixel is stored using 24 bits.


In [11]:
print 'The image dimensions are ' + str(china.shape)
print 'The RGB values of pixel (2, 2) are ' + str(china[2,2,:])


The image dimensions are (427, 640, 3)
The RGB values of pixel (2, 2) are [174 201 231]

We can envision this image as a cloud of points in a 3-dimensional color space. We'll rescale the colors so they lie between 0 and 1, then reshape the array to be a typical scikit-learn input:


In [12]:
X = (china / 255.0).reshape(-1, 3)
print(X.shape)


(273280, 3)

We now have 273,280 points in 3 dimensions.

Our task is to use KMeans to compress the $256^3$ colors into a smaller number (say, 64 colors). Basically, we want to find $N_{color}$ clusters in the data, and create a new image where each true input color is replaced by the color of the closest cluster center. Compressing the data in this way, each pixel can be represented using only 6 bits (25% of the original 24 bits per pixel).
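To make the 25% figure concrete, here is a back-of-the-envelope size estimate (a sketch that only counts the per-pixel indices plus a 64-entry palette of 24-bit colors, and ignores any file-format overhead):

# rough size estimate for a 64-color palette image
n_pixels = china.shape[0] * china.shape[1]
original_bits = n_pixels * 24                 # 3 channels x 8 bits per pixel
compressed_bits = n_pixels * 6 + 64 * 24      # 6-bit index per pixel + palette
print('original: %d bits, compressed: about %d bits (%.1f%% of the original)'
      % (original_bits, compressed_bits, 100. * compressed_bits / original_bits))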


In [13]:
# subsample the image for speed; the reduced version is used only to fit the K-means model
image = china[::3, ::3]
n_colors = 64

X = (image / 255.0).reshape(-1, 3)
    
model = KMeans(n_colors)
model.fit(X)
labels = model.predict((china / 255.0).reshape(-1, 3))
#print labels.shape
colors = model.cluster_centers_
new_image = colors[labels].reshape(china.shape)
new_image = (255 * new_image).astype(np.uint8)

#For comparison purposes, we pick 64 colors at random
perm = np.random.permutation(range(X.shape[0]))[:n_colors]
colors = X[perm,:]

from scipy.spatial.distance import cdist
labels = np.argmin(cdist((china / 255.0).reshape(-1, 3),colors),axis=1)
new_image_rnd = colors[labels].reshape(china.shape)
new_image_rnd = (255 * new_image_rnd).astype(np.uint8)

# create and plot the new image
with sns.axes_style('white'):
    plt.figure()
    plt.imshow(china)
    plt.title('Original image')

    plt.figure()
    plt.imshow(new_image)
    plt.title('{0} colors'.format(n_colors))
    
    plt.figure()
    plt.imshow(new_image_rnd)
    plt.title('{0} colors'.format(n_colors) + ' (random selection)')


Compare the input and output images: we've reduced the $256^3$ possible colors to just 64. For comparison, an additional image is created by selecting 64 colors at random from the original image. Try reducing the number of colors to 32, 16, and 8, and compare the resulting images.
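One possible way to approach this last exercise (a sketch reusing the image, china and KMeans objects defined above) is to simply loop over several palette sizes:

# sketch: repeat the compression for several palette sizes
for n_colors in [32, 16, 8]:
    model = KMeans(n_colors).fit((image / 255.0).reshape(-1, 3))
    labels = model.predict((china / 255.0).reshape(-1, 3))
    compressed = (255 * model.cluster_centers_[labels].reshape(china.shape)).astype(np.uint8)
    plt.figure()
    plt.imshow(compressed)
    plt.title('{0} colors'.format(n_colors))
    plt.grid(False)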


In [ ]: