Clustering using K-means

Let's jump right in. You will write your own K-means algorithm and test it on some synthetic examples.

As a very brief reminder, the K-means algorithm iteratively hard-assigns each data point to its nearest cluster center, and then moves each center to the mean of the data points assigned to it. Each iteration therefore computes the distance from every data point to every cluster center, picks the closest center to make the hard assignment, and then recomputes the cluster centroids.
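In symbols (standard notation that the notebook itself does not use: $x_i$ is data point $i$, $\mu_k$ is centroid $k$, and $z_i$ is the cluster that point $i$ is assigned to), each iteration alternates an assignment step and an update step:

```latex
z_i = \arg\min_{k \in \{1,\dots,K\}} \lVert x_i - \mu_k \rVert^2
\qquad \text{(assignment)}

\mu_k = \frac{1}{N_k} \sum_{i \,:\, z_i = k} x_i,
\qquad N_k = \bigl\lvert \{\, i : z_i = k \,\} \bigr\rvert
\qquad \text{(update)}
```

Neither step can increase the within-cluster sum of squares $J = \sum_{i=1}^{N} \lVert x_i - \mu_{z_i} \rVert^2$, which is why the iterations settle down, typically at a local rather than a global minimum.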


In [1]:
# Boilerplate setup!
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
# Some utility code for this notebook
import utils

In [2]:
# Let's generate some synthetic data
(amps,means,varis),ax = utils.get_clusters_A()
Xi = utils.sample_clusters(amps, means, varis)
X = np.vstack(Xi)

# Plot the true cluster memberships
plt.clf()
for i,x in enumerate(Xi):
    plt.plot(x[:,0], x[:,1], 'o', color=utils.colors[i])
plt.title('True cluster labels');
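The `utils` module is not shown in this notebook, so what `amps`, `means`, and `varis` mean exactly is an assumption here. If they are per-cluster mixture weights, cluster centers, and isotropic variances, a stand-in for `utils.sample_clusters` might look like the sketch below (`sample_gaussian_clusters` is a hypothetical name, and the parameters in the usage lines are made up):

```python
import numpy as np

def sample_gaussian_clusters(amps, means, varis, N=200, rng=None):
    """Hypothetical stand-in for utils.sample_clusters.

    Assumes amps are relative cluster weights, means is a (K, D) array of
    cluster centers, and varis holds one isotropic variance per cluster.
    Returns a list of (N_k, D) arrays, one per cluster, like Xi above.
    """
    rng = np.random.default_rng(rng)
    amps = np.asarray(amps, dtype=float)
    means = np.asarray(means, dtype=float)
    D = means.shape[1]
    # Split the N points between clusters in proportion to their weights
    counts = rng.multinomial(N, amps / amps.sum())
    clusters = []
    for n_k, mu, var in zip(counts, means, varis):
        # Isotropic Gaussian: standard deviation sqrt(var) along every axis
        clusters.append(mu + np.sqrt(var) * rng.standard_normal((n_k, D)))
    return clusters

# Made-up parameters, for illustration only
demo_parts = sample_gaussian_clusters([1, 1, 2], [[0, 0], [4, 0], [2, 4]],
                                      [0.5, 0.5, 1.0], N=300, rng=0)
demo_X = np.vstack(demo_parts)  # (300, 2), ordered cluster by cluster
```

Stacking the per-cluster arrays in order, as `np.vstack(Xi)` does above, keeps the true labels recoverable from each point's position.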



In [3]:
# Now we're going to define our K-means algorithm
# There is a spot in here where you need to fill in some code!

def k_means(X, K, diagnostics=None):
    '''
    K-means clustering algorithm.
    
    *X*: (N, D) array of data points
    *K*: integer number of clusters to find
    *diagnostics*: optional function called once per round as
        diagnostics(step, X, K, centroids, newcentroids, nearest)

    Returns: *centroids*: (K,D) array of cluster centroids
    '''

    N,D = X.shape
    
    # Initialize randomly: draw K integers (no duplicates)
    # using permutation() is overkill, but easy
    I = np.random.permutation(N)[:K]
    centroids = X[I]

    # Loop until convergence... or 20 steps in case you mess up!
    for step in range(20):
        # Compute the distance from each data point in X
        # to each cluster centroid.
        distances = utils.distance_matrix(X, centroids)
        # 'distances' has shape (N, K)
        
        # Find the nearest cluster centroid for each data point
        nearest = np.argmin(distances, axis=1)
        # 'nearest' has shape (N)

        # Compute the new centroids... they will go in this array:
        newcentroids = np.zeros((K, D))
        
        ##### ADD CODE HERE to compute 'newcentroids' #####
        # You want the new centroid for cluster k to be the mean of all
        # the data points that are nearest to the current cluster k.

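        # Here's my solution:
        for k in range(K):
            I = np.flatnonzero(nearest == k)
            if len(I) == 0:
                # No points are nearest to this centroid: keep it where it
                # is instead of taking the mean of an empty set (NaN).
                newcentroids[k,:] = centroids[k,:]
            else:
                newcentroids[k,:] = np.mean(X[I,:], axis=0)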
        
        # Note:
        # data point i is at x,y = (X[i,0], X[i,1]) and is closest
        # to cluster centroid index nearest[i].
        
        # Make diagnostic plot for this round?
        if diagnostics is not None:
            diagnostics(step, X, K, centroids, newcentroids, nearest)

        # Have we converged? (A crude test: no centroid coordinate moved by more than 1e-8.)
        if np.max(np.abs(newcentroids - centroids)) < 1e-8:
            centroids = newcentroids
            break

        centroids = newcentroids

    return centroids
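`utils.distance_matrix` is not shown in this notebook either. Below is a minimal NumPy sketch of what it presumably computes (Euclidean distances, an assumption), together with a loop-free version of the centroid update that keeps the old centroid when a cluster ends up empty. `pairwise_distances` and `update_centroids` are hypothetical names:

```python
import numpy as np

def pairwise_distances(X, C):
    """Euclidean distances between rows of X (N, D) and rows of C (K, D).

    Broadcasting (N, 1, D) - (1, K, D) gives an (N, K, D) array of differences.
    """
    diff = X[:, np.newaxis, :] - C[np.newaxis, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=2))  # shape (N, K)

def update_centroids(X, nearest, old_centroids):
    """Mean of the points assigned to each cluster, without a Python loop."""
    K, D = old_centroids.shape
    sums = np.zeros((K, D))
    # Unbuffered add: sums[k] += X[i] for every i with nearest[i] == k
    np.add.at(sums, nearest, X)
    counts = np.bincount(nearest, minlength=K)  # points per cluster
    new = old_centroids.astype(float).copy()    # empty clusters keep old value
    nonempty = counts > 0
    new[nonempty] = sums[nonempty] / counts[nonempty, np.newaxis]
    return new

X_demo = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0]])
C_demo = np.array([[0.0, 1.0], [10.0, 0.0], [50.0, 50.0]])
nearest_demo = np.argmin(pairwise_distances(X_demo, C_demo), axis=1)
print(nearest_demo)  # [0 0 1]
# Centroid 2 gets no points, so it stays at (50, 50)
print(update_centroids(X_demo, nearest_demo, C_demo))
```

The broadcasting version builds an (N, K, D) temporary, which is fine for notebook-sized data but uses more memory than a loop over clusters.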

In [4]:
# Here we're actually going to run the function we defined above.
# Try playing around with the value of K to see what happens.
# You can also try different random initializations by calling
# np.random.seed(42) (or your favorite number) before k_means().
K = 3

# You can call the plotting routine each iteration to see what
# happens as the algorithm proceeds:
# centroids = k_means(X, K, diagnostics=utils.plot_kmeans)
centroids = k_means(X, K)

utils.plot_kmeans(0, X, K, centroids, centroids, None, show=False)
plt.title('K-means result');
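Because the result depends on which data points are drawn as initial centroids, one common remedy (not something this notebook does) is to run K-means from several random starts and keep the run with the smallest within-cluster sum of squares. A self-contained sketch, with hypothetical names `lloyd`, `inertia`, and `best_of_restarts` and made-up demo data:

```python
import numpy as np

def lloyd(X, K, rng, max_steps=100):
    """Plain K-means (Lloyd's algorithm) started from K random data points."""
    centroids = X[rng.choice(len(X), size=K, replace=False)]
    for _ in range(max_steps):
        d = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        nearest = np.argmin(d, axis=1)
        # Mean of each cluster; an empty cluster keeps its old centroid
        new = np.array([X[nearest == k].mean(axis=0) if np.any(nearest == k)
                        else centroids[k] for k in range(K)])
        if np.allclose(new, centroids):
            break
        centroids = new
    return centroids

def inertia(X, centroids):
    """Within-cluster sum of squared distances to the nearest centroid."""
    d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    return np.sum(np.min(d2, axis=1))

def best_of_restarts(X, K, n_restarts=10, seed=0):
    """Run lloyd() n_restarts times and keep the lowest-inertia centroids."""
    rng = np.random.default_rng(seed)
    runs = [lloyd(X, K, rng) for _ in range(n_restarts)]
    return min(runs, key=lambda c: inertia(X, c))

# Three made-up, well-separated blobs of 50 points each
rng_data = np.random.default_rng(1)
X_demo = np.vstack([rng_data.normal(loc, 0.3, size=(50, 2))
                    for loc in ([0, 0], [3, 0], [0, 3])])
best = best_of_restarts(X_demo, 3)
print(best.shape)  # (3, 2)
```

Comparing `inertia` across runs only makes sense for a fixed K: it always shrinks as K grows, so it cannot by itself tell you which K to use.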



In [5]:
# Now let's test out some harder datasets and see what happens

# This dataset has many more members in one cluster than the other
(amps,means,varis),ax = utils.get_clusters_C()
Xi = utils.sample_clusters(amps, means, varis, N=500)
X = np.vstack(Xi)

# Plot the true cluster memberships
plt.clf()
for i,x in enumerate(Xi):
    plt.plot(x[:,0], x[:,1], 'o', color=utils.colors[i])
plt.title('True cluster labels');
plt.show();

K = 2
centroids = k_means(X, K)
utils.plot_kmeans(0, X, K, centroids, centroids, None, show=False)
plt.title('K-means result');
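To see where K-means disagrees with the true labels on this imbalanced dataset, it can help to count how each true cluster is split across the K-means clusters. K-means separates clusters with straight boundaries halfway between centroids, so points from a large or spread-out cluster can end up assigned to a neighboring centroid. A minimal sketch (`cluster_confusion` is a hypothetical helper; the demo labels are made up):

```python
import numpy as np

def cluster_confusion(true_labels, assigned, n_true, n_found):
    """counts[t, k] = number of points from true cluster t assigned to cluster k."""
    counts = np.zeros((n_true, n_found), dtype=int)
    # Unbuffered add, so repeated (t, k) pairs are all counted
    np.add.at(counts, (true_labels, assigned), 1)
    return counts

true_demo = np.array([0, 0, 0, 0, 1, 1])
assigned_demo = np.array([0, 0, 1, 1, 1, 1])
print(cluster_confusion(true_demo, assigned_demo, 2, 2))
# [[2 2]
#  [0 2]]
```

In the notebook, the true labels could be rebuilt from the order used in `np.vstack(Xi)` as `np.concatenate([np.full(len(x), i) for i, x in enumerate(Xi)])`, and the assignments as `np.argmin(utils.distance_matrix(X, centroids), axis=1)`. K-means cluster indices are arbitrary, so read the table by rows rather than expecting a diagonal.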


