Introducing K-Means

K-Means is an algorithm for unsupervised clustering: that is, it finds clusters in data based on the data attributes alone, without using any labels.


In [ ]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

In [ ]:
from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=300, centers=4,
                  random_state=0, cluster_std=0.60)
plt.scatter(X[:, 0], X[:, 1], s=50);

The K-Means Algorithm: Lloyd’s algorithm

K-Means works as follows:

  1. Guess some cluster centers
  2. Repeat until convergence
    1. Assign points to the nearest cluster center
    2. Set each cluster center to the mean of the points assigned to it
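
The steps above can be sketched directly in NumPy. This is a minimal illustration rather than scikit-learn's implementation: the function name `lloyd_kmeans`, the random choice of initial centers, the handling of empty clusters, and the toy data are all assumptions made for this example.

```python
import numpy as np

def lloyd_kmeans(X, k, n_iter=100, seed=0):
    """Minimal sketch of Lloyd's algorithm; returns (centers, labels)."""
    rng = np.random.default_rng(seed)
    # Step 1: guess the centers by picking k distinct data points at random.
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Step 2.1: assign every point to its nearest center.
        dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 2.2: move each center to the mean of its assigned points
        # (assumption: an empty cluster keeps its previous center).
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        # Convergence: stop once the centers no longer move.
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    return centers, labels

# Toy data (hypothetical): two blobs around (0, 0) and (5, 5).
rng = np.random.default_rng(42)
toy = np.vstack([rng.normal(loc, 0.3, size=(50, 2)) for loc in ([0, 0], [5, 5])])
centers, labels = lloyd_kmeans(toy, k=2)
print(centers)
```

The result depends on the initial guess, which is why the sketch exposes a `seed` argument.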

In [ ]:
from sklearn.cluster import KMeans
clf = KMeans(n_clusters=4)  # 4 clusters
clf.fit(X)
clusters = clf.predict(X)  # cluster index for each point
cluster_centers = clf.cluster_centers_
plt.scatter(X[:, 0], X[:, 1], c=clusters, s=50, cmap='rainbow');

In [ ]:
plt.scatter(X[:, 0], X[:, 1], c=clusters, s=50, cmap='rainbow');
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1],
            s=100, linewidth=2.0, c='black');

KMeans Caveats

  1. The algorithm can converge to a solution that is not the global minimum. To reduce this risk, scikit-learn's KMeans can be run from several random initializations (the n_init parameter) and keeps the best result.

  2. The number of clusters must be set beforehand...
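
Both caveats can be explored with a short experiment. The sketch below assumes the same kind of blob data as above (regenerated here as `X_demo` so the cell is self-contained). It compares a single random initialization against several, using the fitted model's `inertia_` (the within-cluster sum of squared distances), and then records the inertia for a range of cluster counts. The range 1 to 7 and the variable names are illustrative choices, and comparing inertia across k is only one common heuristic for choosing the number of clusters.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X_demo, _ = make_blobs(n_samples=300, centers=4,
                       random_state=0, cluster_std=0.60)

# Caveat 1: a single random start may end in a poor local minimum;
# with n_init > 1 the run with the lowest inertia is kept.
single = KMeans(n_clusters=4, init="random", n_init=1, random_state=0).fit(X_demo)
multi = KMeans(n_clusters=4, init="random", n_init=10, random_state=0).fit(X_demo)
print("1 init:", single.inertia_, " 10 inits:", multi.inertia_)

# Caveat 2: k has to be chosen up front, so try several values
# and look at how the inertia changes.
inertias = {}
for k in range(1, 8):
    model = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X_demo)
    inertias[k] = model.inertia_
print(inertias)
```

Inertia tends to keep shrinking as k grows, so the value to look for is the point where the improvement levels off, not the smallest inertia.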

Application of KMeans to Digits

For a closer-to-real-world example, we will look at the digits data. This dataset is made up of 1797 8x8 images. Each image, like the one shown below, is of a hand-written digit (0-9).

Here we'll use KMeans to automatically cluster the data in 64 dimensions, and then look at the cluster centers to see what the algorithm has found.


In [ ]:
from sklearn.datasets import load_digits
digits = load_digits()

# Display the first digit
plt.figure(1, figsize=(3, 3))
plt.imshow(digits.images[0], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

To use an 8x8 image like this, we first need to transform it into a feature vector of length 64.
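
As a small sketch of that transformation, the example below flattens a stack of 8x8 arrays into rows of 64 features and back again. The `images` array is a hypothetical stand-in built with NumPy, not the digits data itself.

```python
import numpy as np

# Hypothetical stand-in for a stack of three 8x8 images.
images = np.arange(3 * 8 * 8, dtype=float).reshape(3, 8, 8)

# Flatten each 8x8 image row by row into a vector of 64 features.
features = images.reshape(len(images), -1)
print(features.shape)  # (3, 64)

# Reshaping a feature vector back to 8x8 recovers the original image.
restored = features[0].reshape(8, 8)
```

The same `reshape((8, 8))` call is what turns a 64-dimensional cluster center back into something that can be displayed as an image.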


In [ ]:
print(digits.data.shape)

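Perform K-Means clustering on the digits data with n_clusters = 10. Store the fitted estimator in clf, since the plotting cell that follows reads clf.cluster_centers_.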


In [ ]:
# Your code goes here

Plot each of these cluster centers to see what they represent:


In [ ]:
fig = plt.figure(figsize=(8, 3))
for i in range(10):
    ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
    ax.imshow(clf.cluster_centers_[i].reshape((8, 8)), cmap=plt.cm.binary)

We see that even without the labels, KMeans is able to find clusters whose means are recognizable digits.

