An Introduction to K-Means Clustering

by Scott Hendrickson & Fiona Pigott

K-Means is for learning unknown categories

K-means is a machine learning technique for learning unknown categories--in other words, a technique for unsupervised learning. K-means tries to group n-dimensional data into clusters, where the actual position of those clusters is unknown.

Basic Idea

From the Wikipedia article on k-means clustering:

"k-means clustering aims to partition n observations into k clusters in which each observation belongs to the cluster with the nearest mean, serving as a prototype of the cluster"

Basically, k-means assumes that for some sensible distance metric, it's possible to partition data into groups around the "center" ("centroid") of different naturally separated clusters in the data.
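
Formally, with Euclidean distance, the standard k-means objective is to choose cluster assignments $S = \{S_1, \dots, S_k\}$ (where $\mu_i$ is the mean of the points in cluster $S_i$) that minimize the within-cluster sum of squared distances:

$$\underset{S}{\arg\min} \sum_{i=1}^{k} \sum_{x \in S_i} \lVert x - \mu_i \rVert^2$$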

This concept can be very useful for separating datasets that came from separate generative processes, where the location of each dataset is essentially unknown. It only works well if we expect each dataset to be clustered around its mean, and the means to be reasonably different. A classic example of where k-means would not separate datasets well is when the datasets have different distributions but similar means.

Not a good problem to apply k-means to:

Good problem to apply k-means to:

Question:

Is there an underlying structure to my data? Does my data have defined categories that I don't know about? How can I identify which data point belongs to which category?

Solution:

Given a set of cluster centers ("centroids")--and we'll talk about how to choose them--label each data point with the centroid it is closest to.

Algorithm

0) Have a dataset that you want to sort into clusters
1) Choose a number of clusters that you're going to look for (there are ways to optimize this, but you have to fix it for the next step)
2) Guess at cluster membership for each data point (basically, for each data point, randomly assign it to a cluster)
3) Find the center ("centroid") of each cluster (with the data points that you've assigned to it)
4) For each centroid, find which data points are closest to it, and assign those data points to its cluster
5) Repeat 3 & 4 (re-evaluate centroids based on new cluster membership, then re-assign clusters based on new centroids) until little or nothing changes between iterations; a tiny worked example of this loop is sketched just below, and the rest of the notebook builds the full 2-D version step by step
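
To make the loop concrete, here is a tiny, self-contained sketch of steps 2-4 on a made-up 1-D dataset (the points, labels, and names below are invented purely for illustration--the rest of the notebook builds the real 2-D version piece by piece):


In [ ]:
# toy 1-D illustration of the k-means loop (illustrative only)
points = [0.0, 1.0, 2.0, 9.0, 10.0, 11.0]
labels = ["a", "b", "a", "b", "a", "b"]   # step 2: a (deliberately bad) initial guess

for _ in range(5):   # repeat steps 3 & 4 a few times
    # step 3: the centroid of each cluster is the mean of its member points
    centroids = {c: sum(p for p, l in zip(points, labels) if l == c) /
                    sum(1 for l in labels if l == c)
                 for c in set(labels)}
    # step 4: re-assign each point to the cluster with the nearest centroid
    labels = [min(centroids, key=lambda c: abs(p - centroids[c])) for p in points]

print(labels)      # -> ['a', 'a', 'a', 'b', 'b', 'b']
print(centroids)   # -> {'a': 1.0, 'b': 10.0}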

0) A cloud of data in two dimensions

Setting up an example of data that could be separated by k-means:

First, we'll generate a synthetic dataset from two different spherical Gaussian distributions, setting the spacing so that the clouds of data overlap a little.


In [ ]:
# Import some python libraries that we'll need
import matplotlib.pyplot as plt
import random
import math
import sys
%matplotlib inline

In [ ]:
def make_data(n_points, n_clusters=2, dim=2, sigma=1):
    # data is stored as a list of `dim` lists, one per coordinate;
    # cluster i is a spherical Gaussian centered at (3*i, 3*i, ...) with standard deviation sigma
    x = [[] for i in range(dim)]
    for i in range(n_clusters):
        for d in range(dim):
            x[d].extend([random.gauss(i*3, sigma) for j in range(n_points)])
    return x

In [ ]:
# make our synthetic data
num_clusters = 2
num_points = 100
data_sample = make_data(num_points, num_clusters)

In [ ]:
# plot our synthetic data
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(*data_sample)
ax.set_title("Sample dataset, {} points per cluster, {} clusters".format(num_points,num_clusters))

1) How might we identify the two clusters?

We're going to set $k = 2$ before trying to use k-means to separate the clusters

We happen to know that $k = 2$ because we just made up this data with two distributions. I'll talk a little at the end about how to guess $k$ for a real-world dataset.

Because we created this example, we know the "truth"

We know which data came from which distribution (this is what k-means is trying to discover).

Here's the truth, just to compare:


In [ ]:
# plot our synthetic data, colored by the true cluster membership
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(data_sample[0][:num_points], data_sample[1][:num_points])
ax.scatter(data_sample[0][num_points:], data_sample[1][num_points:])
ax.set_title("Sample dataset, {} points per cluster, {} clusters".format(num_points,num_clusters))

2) Start by guessing the cluster membership

In this case, guessing means "randomly assign cluster membership." There are other heuristics that you could use to make an initial guess, but we won't get into those here.


In [ ]:
# each cluster membership is going to have a color label ("red" cluster, "orange" cluster, etc)
co = ["red", "orange", "yellow", "green", "purple", "blue", "black","brown"]
def guess_clusters(x, n_clusters):
    # req co list of identifiers
    for i in range(len(x[0])):
        return [ co[random.choice(range(n_clusters))] for i in range(len(x[0]))]

In [ ]:
# now guess the cluster membership--simply by randomly assigning a cluster label 
# "orange" or "red" to each of the data points 
membership_2 = guess_clusters(data_sample,2)
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(*data_sample, color=membership_2)
ax.set_title("Data set drawn from 2 different 2D Gaussian distributions")

3) Find the center of a set of data points

We'll need a way of determining the center of a set of data, after we make a guess at cluster membership.

In this case, we'll find the centers of the two clusters that we guessed about.


In [ ]:
def centroid(x):
    # x is a list of dimensions, each a list of coordinates; the centroid is the mean
    # along each dimension, returned in the same nested-list shape: [[mean_x], [mean_y], ...]
    return [[sum(col)/float(len(col))] for col in x]

In [ ]:
# function to select members of only one cluster
def select_members(x, membership, cluster):
    return [ [i for i,label in zip(dim, membership) if label == cluster] for dim in x ]

In [ ]:
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(*select_members(data_sample, membership_2, "red"), color="red")
ax.scatter(*centroid(select_members(data_sample, membership_2, "red")), color="black", marker="*", s = 100)
ax.set_title("Centroid of the 'red' cluster (black star)")

In [ ]:
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(*select_members(data_sample, membership_2, "orange"), color="orange")
ax.scatter(*centroid(select_members(data_sample, membership_2, "orange")), color="black", marker="*", s = 100)
ax.set_title("Centroid of the 'orange' cluster (black star)")

4) Update membership of points to closest centroid

First, we need a way to compute distances (here, we'll use it to measure the distance from each data point to a centroid):


In [ ]:
def distance(p1, p2):
    # points are lists of single-element lists, one per dimension (the shape centroid() returns),
    # so unwrap each coordinate before computing the Euclidean distance
    return math.sqrt(sum([(i[0]-j[0])**2 for i,j in zip(p1, p2)]))

# here's the distance between two points, just to show how it works
print("Distance between (-1,-1) and (2,3): {}".format(distance([[-1],[-1]],[[2],[3]])))

In [ ]:
def reassign(x, centroids):
    membership = []
    for idx in range(len(x[0])):
        min_d = sys.maxsize
        cluster = ""
        for c, vc in centroids.items():
            # distance from this data point to the centroid of cluster c
            dist = distance(vc, [[t[idx]] for t in x])
            if dist < min_d:
                min_d = dist
                cluster = c
        membership.append(cluster)
    return membership

In [ ]:
cent_2 = {i:centroid(select_members(data_sample, membership_2, i)) for i in co[:2]}
membership_2 = reassign(data_sample, cent_2)

In [ ]:
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(*data_sample, color=membership_2)
ax.scatter(*cent_2["red"], color="black", marker="*", s = 360)
ax.scatter(*cent_2["orange"], color="black", marker="*", s = 360)
ax.scatter(*cent_2["red"], color="red", marker="*", s = 200)
ax.scatter(*cent_2["orange"], color="orange", marker="*", s = 200)

5) Put it all together so that we can iterate

Now we're going to iterate--find centroids, reassign clusters, repeat--until the error score (and with it the centroid positions) stops changing very much.


In [ ]:
# find the centroid of each cluster, given a membership labeling
def get_centroids(x, membership):
    return {i:centroid(select_members(x, membership, i)) for i in set(membership)}

# redefine reassign so that it also returns an overall error score
def reassign(x, centroids):
    membership, scores = [], {}
    # step through all the data points
    for idx in range(len(x[0])):
        min_d, cluster = sys.maxsize, None # set the min distance to a large number (we're about to minimize it)
        for c, vc in centroids.items():
            # distance from this data point to the centroid of cluster c
            dist = distance(vc, [[t[idx]] for t in x])
            if dist < min_d:
                min_d = dist
                cluster = c
        # add this point's distance to its nearest centroid to that cluster's score
        scores[cluster] = min_d + scores.get(cluster, 0)
        membership.append(cluster)
    # return the membership & the total score over all clusters, averaged over all points
    return membership, sum(scores.values())/float(len(x[0]))

def k_means(data, k):
    # start with random distribution
    membership = guess_clusters(data, k)
    score, last_score = 0.0, sys.maxsize
    while abs(last_score - score) > 1e-7:
        last_score = score
        c = get_centroids(data, membership)
        membership, score = reassign(data, c)
        #print(last_score - score)
    return membership, c, score

In [ ]:
mem, cl, s = k_means(data_sample, 2)
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(*data_sample, color = mem)
for i, pt in cl.items():
    ax.scatter(*pt, color="black", marker="*", s = 16*8)
ax.set_title("Clustering from k-means")

K-means with real data

Figuring out how many clusters to look for (that pesky "Step 1")

Now, one thing we haven't covered yet is how to decide on the number of clusters to look for in the first place. There are several different heuristics that we can use to figure out what the "best" number of clusters is (we go into this more in https://github.com/DrSkippy/Data-Science-45min-Intros/tree/master/choosing-k-in-kmeans).

The one heuristic that we're going to talk about here is finding the "knee" in the k-means error function.

The error function:

In this case, the error function is the sum of the distances from each data point to its assigned centroid, summed over all of the clusters (and divided by the total number of points, so it's really an average distance). The further each data point is from its assigned centroid, the larger this error score is.
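
In symbols (with $n$ data points $x_j$, and $\mu_{c(j)}$ the centroid of the cluster that point $j$ is assigned to), the score returned by our reassign function is the mean distance from each point to its assigned centroid:

$$E(k) = \frac{1}{n} \sum_{j=1}^{n} \lVert x_j - \mu_{c(j)} \rVert$$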

Look for the "knee":

When I say "knee" I mean to look for a bend in the graoh of the error score vs $k$. The idea is to find the place where you get a smaller decrease in the error (distance from each data point to a centroid) for every increase in the number of clusters ($k$).


In [ ]:
err = []
trial_ks = range(1,5)
results = {}
for k in trial_ks:
    mem_2, cl_2, s_2 = k_means(data_sample, k)
    results[k] = mem_2
    err.append(s_2)
    
f, axes = plt.subplots(1, len(trial_ks), sharey=True, figsize = (18,4))
for i,k in enumerate(trial_ks):
    axes[i].set_aspect('equal')
    axes[i].set_title("k-means results with k = {} \n error = {:f}".format(k, err[i]))
    axes[i].scatter(*data_sample, color = results[k])

In [ ]:
# plot the error as a function of the number of clusters
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.plot(trial_ks,err,'o--')
ax.set_title("Error as a funtion of k")
ax.xaxis.set_ticks(trial_ks)
_ = ax.set_xlabel("number of clusters (k)")

In [ ]:
# a different example, this time with 4 clusters
ex4 = make_data(200, 4) 
err4 = []
trial_ks_4 = range(1,9)
results_4 = {}
for k in trial_ks_4:
    mem_ex4, cl_ex4, s_ex4 = k_means(ex4, k)
    results_4[k] = mem_ex4
    err4.append(s_ex4)
    
f, axes = plt.subplots(2, int(len(trial_ks_4)/2), sharey=True, figsize = (18,11))
for i,k in enumerate(trial_ks_4):
    axes[i // 4][i % 4].set_aspect('equal')
    axes[i // 4][i % 4].set_title("k-means results with k = {} \n error = {:f}".format(k, err4[i]))
    axes[i // 4][i % 4].scatter(*ex4, color = results_4[k])

In [ ]:
# plot the error as a function of the number of clusters
fig = plt.figure(figsize = [6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.plot(trial_ks_4,err4,'o--')
ax.set_title("Error as a funtion of k")
ax.xaxis.set_ticks(trial_ks_4)
_ = ax.set_xlabel("number of clusters (k)")

Choose the number that seems to make sense

The "number that makes sense" is where the "knee" in the error function occurs. Basically, where we start getting little or no decrease in the error, even as we increase $k$.

Here, you might see that after $k=4$, the error stops decreasing very quickly, and you might choose $k=4$ as the best number of clusters in the data.

K-means in the wild: don't use this code. Use Scikit-learn!

The tutorial on using Scikit-learn (https://github.com/DrSkippy/Data-Science-45min-Intros/tree/master/sklearn-101) and also the one on finding the number of clusters (https://github.com/DrSkippy/Data-Science-45min-Intros/tree/master/choosing-k-in-kmeans) both use the Scikit-learn K-means algorithm, which has a well-thought-out API and is way faster than a homemade version. Try it out!
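
To give a rough idea of what that looks like, here is a minimal sketch using scikit-learn's KMeans on the same data_sample from above (this assumes scikit-learn and numpy are installed, and reuses data_sample and plt from earlier cells):


In [ ]:
# minimal scikit-learn sketch (assumes scikit-learn and numpy are available)
import numpy as np
from sklearn.cluster import KMeans

X = np.array(data_sample).T                   # KMeans expects rows = points, columns = dimensions

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)

print(km.cluster_centers_)                    # the fitted centroids
print(km.inertia_)                            # sum of squared distances to the nearest centroid
print(km.labels_[:10])                        # cluster label (0 or 1) for the first ten points

fig = plt.figure(figsize=[6,6])
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.scatter(X[:,0], X[:,1], c=km.labels_)
_ = ax.set_title("Clustering from scikit-learn's KMeans")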

Thanks for playing!

