CSE 6040, Fall 2015 [28]: K-means Clustering, Part 2

Last time, we implemented the basic version of K-means. In this lecture we will explore some advanced techniques to improve the performance of K-means.


In [ ]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

Read in data


In [ ]:
df = pd.read_csv('http://vuduc.org/cse6040/logreg_points_train.csv')
points = df.as_matrix(['x_1', 'x_2'])
labels = df['label'].as_matrix()
n = points.shape[0]
d = points.shape[1]
k = 2

In [ ]:
df.head()

In [ ]:
def init_centers(X, k):
    # Sample k distinct rows of X (without replacement) as the initial centers;
    # sampling with replacement could yield duplicate centers and empty clusters.
    sampling = np.random.choice(X.shape[0], size=k, replace=False)
    return X[sampling, :]

Fast implementation of the distance matrix computation

The idea is to expand the squared distance as $$\|x - c\|^2 = \|x\|^2 - 2\langle x, c \rangle + \|c\|^2.$$ This has several advantages:

  1. The centers are fixed during a single iteration, so their squared norms only need to be computed once per iteration.
  2. Data points are often sparse, but centers are not; the inner-product term can exploit that sparsity.
  3. Implemented cleverly, it avoids explicit Python `for` loops. (A sketch of one such solution appears after the exercise cell below.)

In [ ]:
def compute_d2(X, centers):
    # Baseline: compute the squared distances row by row with an explicit loop.
    n = X.shape[0]
    k = centers.shape[0]
    D = np.empty((n, k))
    for i in range(n):
        D[i, :] = np.linalg.norm(X[i, :] - centers, axis=1) ** 2

    return D

In [ ]:
def compute_d2_fast(X, centers):

    # @YOUSE: compute a length-n array whose i-th entry is the squared norm of row i of X
    first_term = 

    # @YOUSE: compute an (n x k) matrix whose (i, j) entry is twice the inner product of row i of X and row j of centers
    second_term = 

    # @YOUSE: compute a length-k array whose j-th entry is the squared norm of row j of centers
    third_term = 

    D = np.tile(first_term, (centers.shape[0], 1)).T - second_term + np.tile(third_term, (X.shape[0], 1))
    D[D < 0] = 0  # clamp small negative entries caused by floating-point round-off

    return D
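
For reference, here is one possible way to fill in the placeholders above, using Numpy broadcasting instead of `np.tile`. Treat it as a sketch (the name `compute_d2_fast_solution` is ours); your own solution may differ.

In [ ]:
def compute_d2_fast_solution(X, centers):
    # ||x_i||^2 for each point: a length-n array.
    first_term = np.sum(X ** 2, axis=1)
    # 2 <x_i, c_j> for every point/center pair: an (n x k) matrix.
    second_term = 2.0 * X.dot(centers.T)
    # ||c_j||^2 for each center: a length-k array.
    third_term = np.sum(centers ** 2, axis=1)
    # Broadcasting adds the point norms down the columns and the center
    # norms across the rows, matching the expansion above.
    D = first_term[:, np.newaxis] - second_term + third_term[np.newaxis, :]
    D[D < 0] = 0  # clamp small negatives due to round-off
    return D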

Let's compare the running times of the two implementations.


In [ ]:
centers = init_centers(points, k)
%timeit D = compute_d2(points, centers)
%timeit D = compute_d2_fast(points, centers)

In [ ]:
def cluster_points(D): 
    return np.argmin(D, axis=1)

In [ ]:
def update_centers(X, clustering):
    d = X.shape[1]
    centers = np.empty((k, d))
    for i in range(k):
        members = (clustering == i)
        if any(members):
            centers[i, :] = np.mean(X[members, :], axis=0)
        else:
            # Re-seed an empty cluster with a random data point, so that its
            # center is not left uninitialized (np.empty contains garbage).
            centers[i, :] = X[np.random.randint(X.shape[0]), :]
    return centers

In [ ]:
def WCSS(D):
    min_val = np.amin(D, axis=1)
    return np.sum(min_val)
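
Here, WCSS stands for the within-cluster sum of squares, the objective K-means minimizes. Given the squared-distance matrix $D$, each point contributes the squared distance to its nearest center: $$\mathrm{WCSS} = \sum_{i=1}^{n} \min_{1 \le j \le k} \|x_i - c_j\|^2.$$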

In [ ]:
def has_converged(old_centers, centers):
    # Converged when the set of centers is unchanged (their order may differ).
    return set([tuple(x) for x in old_centers]) == set([tuple(x) for x in centers])
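
If floating-point round-off ever prevents exact equality, a tolerance-based test can be used instead. Here is a sketch (the name `has_converged_approx` and the tolerance value are ours):

In [ ]:
def has_converged_approx(old_centers, centers, tol=1e-12):
    # Sort the centers lexicographically so corresponding rows line up,
    # then compare them within a small tolerance.
    a = np.array(sorted([tuple(x) for x in old_centers]))
    b = np.array(sorted([tuple(x) for x in centers]))
    return a.shape == b.shape and np.allclose(a, b, atol=tol)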

In [ ]:
def kmeans_basic(X, k):
    # Initialize twice so that the first convergence test fails;
    # only `centers` is actually used as the starting point.
    old_centers = init_centers(X, k)
    centers = init_centers(X, k)
    i = 1
    while not has_converged(old_centers, centers):
        old_centers = centers
        D = compute_d2_fast(X, centers)
        clustering = cluster_points(D)
        centers = update_centers(X, clustering)
        print("iteration {}: WCSS = {}".format(i, WCSS(D)))
        i += 1
    return centers, clustering

In [ ]:
centers, clustering = kmeans_basic(points, k)

In [ ]:
def plot_clustering_k2(centers, clustering):
    df['clustering'] = clustering
    sns.lmplot(data=df, x="x_1", y="x_2", hue="clustering", fit_reg=False)
    # Match the star markers to seaborn's hue ordering, which depends on
    # which cluster label happens to appear first in the data.
    if df['clustering'][0] == 0:
        colors = ['b', 'g']
    else:
        colors = ['g', 'b']
    plt.scatter(centers[:, 0], centers[:, 1], s=500, c=colors, marker='*')

In [ ]:
plot_clustering_k2(centers, clustering)

K-means implementation in Scipy

In fact, Scipy ships with a built-in K-means implementation, in the scipy.cluster.vq module.

Scipy builds on top of Numpy and is installed by default in many Python distributions.


In [ ]:
from scipy.cluster.vq import kmeans,vq

In [ ]:
# The `distortion` returned by Scipy's kmeans is the mean Euclidean distance
# between each point and its nearest center -- closely related to WCSS,
# though WCSS sums *squared* distances. It is called "distortion" in the
# Scipy documentation because clustering can be used in (lossy) compression.
centers, distortion = kmeans(points, k)

# vq returns the clustering (the assignment of each point to a group)
# based on the centers obtained by the kmeans function.
# The underscore is the conventional name for a return value we ignore.
clustering, _ = vq(points, centers)
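
One caveat worth knowing: Scipy's documentation recommends rescaling each feature to unit variance before running kmeans, using its whiten helper. We skip that above since both features here are on similar scales, but a minimal sketch (the `_w` names are ours) would look like this:

In [ ]:
from scipy.cluster.vq import whiten

# Rescale each column of the data to unit variance before clustering.
points_w = whiten(points)
centers_w, distortion_w = kmeans(points_w, k)
clustering_w, _ = vq(points_w, centers_w)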

In [ ]:
plot_clustering_k2(centers, clustering)

Elbow method to determine a good k

The elbow method is a general rule of thumb for selecting parameters.

The idea is to choose a number of clusters such that adding one more cluster does not model the data much better. Concretely: plot the distortion as a function of $k$ and look for the value of $k$ where the curve bends most sharply (the "elbow").


In [ ]:
df_kcurve = pd.DataFrame(columns = ['k', 'distortion']) 
for i in range(1,10):
    _, distortion = kmeans(points, i)
    df_kcurve.loc[i] = [i, distortion]

In [ ]:
df_kcurve.plot(x="k", y="distortion")

You can see that the curve bends most sharply at $k=2$: that elbow suggests $k=2$ is a good choice for this data set.

Exercise: implement K-means++

K-means++ differs from K-means only in the initialization step.

Standard K-means selects all k centers at once, uniformly at random from the data points. With bad luck, all k initial centers can land in the same area, which may lead to a bad local optimum or slow convergence.

The idea of K-means++ is to select more spread-out centers. In particular, K-means++ selects the k centers iteratively, one at a time. In the first iteration, it chooses a single data point uniformly at random as the 1st center. In the second iteration, it computes the squared distance between each point and the 1st center, and chooses the 2nd center randomly with probability proportional to that squared distance. In general, once $m < k$ centers have been chosen, the $(m+1)$-th center is drawn with probability proportional to each point's squared distance to its nearest existing center; that is, point $x_i$ is chosen with probability $$p_i = \frac{D(x_i)^2}{\sum_{j=1}^n D(x_j)^2},$$ where $D(x_i)$ denotes the distance from $x_i$ to the nearest center chosen so far. The initialization step finishes when all k centers have been chosen.


In [ ]:
def init_centers_kplusplus(X, k):
    
    # @YOUSE: implement the initialization step of k-means++
    # return centers: a (k x d) matrix
    pass
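
For reference, here is one possible implementation. Treat it as a sketch (the name `init_centers_kplusplus_sketch` is ours); your own solution may differ.

In [ ]:
def init_centers_kplusplus_sketch(X, k):
    n = X.shape[0]
    # Choose the first center uniformly at random.
    centers = [X[np.random.randint(n), :]]
    for _ in range(1, k):
        C = np.array(centers)
        # Squared distance from each point to its nearest chosen center.
        d2 = np.min(((X[:, np.newaxis, :] - C[np.newaxis, :, :]) ** 2).sum(axis=2), axis=1)
        # Draw the next center with probability proportional to d2.
        centers.append(X[np.random.choice(n, p=d2 / d2.sum()), :])
    return np.array(centers)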

In [ ]:
def kmeans_kplusplus(X, k):
    # Start from the K-means++ centers; the random init only serves to make
    # the first convergence test fail, mirroring kmeans_basic above. (Note
    # the original ordering discarded the K-means++ centers immediately.)
    old_centers = init_centers(X, k)
    centers = init_centers_kplusplus(X, k)
    i = 1
    while not has_converged(old_centers, centers):
        old_centers = centers
        D = compute_d2_fast(X, centers)
        clustering = cluster_points(D)
        centers = update_centers(X, clustering)
        print("iteration {}: WCSS = {}".format(i, WCSS(D)))
        i += 1
    return centers, clustering

In [ ]:
centers, clustering = kmeans_kplusplus(points, k)

In [ ]:
plot_clustering_k2(centers, clustering)