Clustering and the k-means Algorithm

Copyright 2020 Allen B. Downey

License: Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)


This notebook introduces cluster analysis and one of the most common algorithms for it, k-means.

It also introduces

  • Jupyter, which is a tool for creating notebooks like this one;

  • NumPy, which we'll use to perform array operations;

  • Pandas, which we'll use to read and clean the data; and

  • scikit-learn, which provides an implementation of k-means.

We'll proceed "top-down"; that is, we'll use scikit-learn first, then we'll open the hood and see how it works.

If you want to follow along:


I am a professor at Olin College, which is a small engineering school near Boston, Massachusetts, USA.

Olin was created in 1999 with the mission to transform engineering education.

In [1]:
<iframe src="!1m14!1m8!1m3!1d1512.1750667940496!2d-71.26457056946273!3d42.29270982134376!3m2!1i1024!2i768!4f13.1!3m3!1m2!1s0x0%3A0xa038229eeed8c35b!2sOlin%20College%20of%20Engineering!5e1!3m2!1sen!2sus!4v1594232142090!5m2!1sen!2sus" width="600" height="450" frameborder="0" style="border:0;" allowfullscreen="" aria-hidden="false" tabindex="0"></iframe>

Classes and books

I have been at Olin since 2003. I teach classes related to software, data science, Bayesian statistics, and physical modeling.

I have written several books on these topics, including Think Python and Think Stats. Most are published by O'Reilly Media, which is famous for putting animals on their covers:

But all of them are freely available from Green Tea Press.

Finally, I write a blog about Data Science and related topics, called Probably Overthinking It.

Jupyter and Colab

Jupyter is a tool for writing notebooks that contain text, code, and results.

You can install Jupyter on your own computer, but can also use services like Colab that run the notebook for you. In that case, you don't have to install anything; you just need a browser.

A notebook contains:

  • Text cells, which contain text in Markdown or HTML, and

  • Code cells, which contain code in Python or one of about 100 other languages.

This is a text cell; the one below is a code cell.

In [2]:
print('Hello, Jupyter')

On Colab, code cells have a triangle "Play" icon on the left side. You can press it to run the code in the cell.

Or if you select a cell by clicking on it, you can run it by pressing Shift-Enter.

As an exercise:

  1. Run the print statement in the previous cell.
  2. Modify the code in that cell and run it again.
  3. Run the next cell, which imports the Python modules we'll use later.

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

Also run the following cell, which defines a function we'll use.

In [4]:
def decorate(**options):
    """Decorate the current axes.
    Call decorate with keyword arguments like
    The keyword arguments can be any of the axis properties
    ax = plt.gca()
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(handles, labels)



Cluster analysis is a set of tools for looking at data and

  • Discovering groups, species, or categories,

  • Defining boundaries between groups.

It is a form of "unsupervised" learning, which means that the only input is the dataset itself; the algorithm is not given any correct examples to learn from.

As an example, I'll used data collected and made available by Dr. Kristen Gorman at the Palmer Long-Term Ecological Research Station in Antarctica.

This dataset was published to support this article: Gorman, Williams, and Fraser, "Ecological Sexual Dimorphism and Environmental Variability within a Community of Antarctic Penguins (Genus Pygoscelis)", March 2014.

The following cell downloads the raw data.

In [5]:
# Load the data files from
# With gratitude to Allison Horst (@allison_horst)

import os

if not os.path.exists('penguins_raw.csv'):

The dataset is stored in a CSV file, which contains one row for each penguin and one column for each variable.

I'll use Pandas to read the CSV file and put the results in a DataFrame.

In [6]:
df = pd.read_csv('penguins_raw.csv')

A DataFrame is like a 2-D array, but it also contains names for the columns and labels for the rows.

The shape of the DataFrame is the number of rows and columns.

The head method displays the first few rows.

In [7]:

Three species of penguins are represented in the dataset: Adelie, Chinstrap and Gentoo, as shown in this illustration (by Allison Horst, available under the CC-BY license):

In this dataset we are told that there are three species, and we are told which species each penguin belongs to. But for purposes of clustering, we'll pretend we don't have this information and we'll see whether the algorithm "discovers" the different species.

The measurements we'll use are:

  • Body Mass in grams (g).

  • Flipper Length in millimeters (mm).

  • Culmen Length in millimeters.

  • Culmen Depth in millimeters.

If you are not familiar with the word "culmen", it refers to the top margin of the beak, as shown in the following illustration (also by Allison Horst):

This might seem like an artificial exercise. If we already know that there are three species, why are we trying to discover them?

For now, I'll just say that it's a learning example. But let's come back to this question: what is unsupervised clustering good for?

Distributions of measurements

The measurements we have will be most useful for clustering if there are substantial differences between species and small variation within species. To see whether that is true, and to what degree, I will plot distributions of measurements for each species.

For convenience, I'll create a new column, called Species2, that contains a shorter version of the species names.

In [8]:
def shorten(species):
    """Select the first word from a string."""
    return species.split()[0]

df['Species2'] = df['Species'].apply(shorten)

I'll use the groupby method to divide the dataset by species.

In [9]:
grouped = df.groupby('Species2')

The result is a GroupBy object that contains the three groups and their names. The following loop prints the group names and the number of penguins in each group.

In [10]:
for name, group in grouped:
    print(name, len(group))

We can use the GroupBy object to extract a column, like flipper length, from each group and compute its mean.

In [11]:
varname = 'Flipper Length (mm)'

for name, group in grouped:
    print(name, group[varname].mean())

We can also use it to display the distribution of values in each group.

In [12]:
for name, group in grouped:
    sns.kdeplot(group[varname], label=name)
         title='Distributions of features')

kdeplot uses kernel density estimation to make a smooth histogram of the values.

It looks like we can use flipper length to identify Gentoo penguins, but not to distinguish the other two species.

To make these steps easier to reuse, I'll wrap them a function.

In [13]:
def make_kdeplots(df, varname):
    """Make a KDE plot for each species.
    df: DataFrame
    varname: string column name
    by: string column name
    returns: dictionary from species name to Cdf
    grouped = df.groupby('Species2')
    for name, group in grouped:
        sns.kdeplot(group[varname], label=name)
         title='Distributions of features')

Now we can use it to explore other features, like culmen length.

In [14]:
make_kdeplots(df, 'Culmen Length (mm)')

It looks like we can use culmen length to identify Adelie penguins.

Exercise: Use make_kdeplots to display the distributions of one of the other two features:

  • 'Body Mass (g)'
  • 'Culmen Depth (mm)'

In [15]:
# Solution goes here

Scatter plot

If we can identify Gentoo penguins by flipper length and Adelie penguins by culmen length, maybe we can combine these variables to identify all three species.

I'll start by making a scatter plot of the data.

In [16]:
var1 = 'Flipper Length (mm)'
var2 = 'Culmen Length (mm)'
var3 = 'Culmen Depth (mm)'
var4 = 'Body Mass (g)'

In [17]:
for name, group in grouped:
    plt.plot(group[var1], group[var2], 
             'o', alpha=0.4, label=name)
decorate(xlabel=var1, ylabel=var2)

Using those two features, we can divide the penguins into clusters with not much overlap.

We're going to make lots of scatter plots, so let's wrap that code in a function.

And we'll generalize it to take by as a parameter, so we can group by any column, not just Species2.

In [18]:
def scatterplot(df, var1, var2, by):
    """Make a scatter plot.
    df: DataFrame
    var1: string column name, x-axis
    var2: string column name, y-axis
    by: string column name, groupby
    grouped = df.groupby(by)
    for name, group in grouped:
        plt.plot(group[var1], group[var2], 
                 'o', alpha=0.4, label=name)
    decorate(xlabel=var1, ylabel=var2)

Here's a scatter plot of flipper and culmen length for the three species.

In [19]:
scatterplot(df, var1, var2, 'Species2')

Exercise: Make a scatter plot using any other pair of variables.

In [20]:
# Solution goes here

We can think of these scatter plots as 2-D views of a 4-D feature space.

Clear the labels

Now, let's pretend we don't know anything about the different species, and we'll see whether we can rediscover these clusters.

To see what the problem looks like, I'll add a column of labels to the DataFrame and set it to 0 for all penguins.

In [21]:
df['labels'] = 0

Now if we group by label, there's only one big cluster.

In [22]:
scatterplot(df, var1, var2, 'labels')

Let's see what happens if we run k-means clustering on this data.


First I'll use the implementation of k-means in scikit-learn; then we'll write our own.

In the dataset, we have 344 penguins and 19 variables.

In [23]:

But some of the variables are NaN, which indicates missing data.

So I'll use dropna to drop any rows that have missing data for the two features we're going to use, flipper length and culmen length.

In [24]:
features = [var1, var2]
data = df.dropna(subset=features).copy()

I'll extract just those two columns as a NumPy array.

In [25]:
M = data[features].to_numpy()

Now we can use KMeans to identify the clusters.

n_clusters indicates how many cluster we want; this parameter is the $k$ the algorithm is named for.

In [26]:
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=3).fit(M)

The result is an object that contains

  • Labels that indicates which cluster each penguin is assigned to, and

  • The centers of the clusters.

I'll store the labels as a columns in data.

In [27]:
data['labels'] = kmeans.labels_

That way we can use scatterplot to show the clusters.

In [28]:
scatterplot(data, var1, var2, 'labels')

The KMeans object also contains the centers of the clusters as coordinate pairs in a NumPy array.

In [29]:

To plot the centers, I'll transpose the array and assign the columns to x and y:

In [30]:
xs, ys = np.transpose(kmeans.cluster_centers_)

I'll plot the centers with x's and o's.

In [31]:
options = dict(color='C3', ls='none', mfc='none')
plt.plot(xs, ys, marker='o', ms=15, **options)
plt.plot(xs, ys, marker='x', ms=10, **options);

As usual, let wrap that up in a function.

In [32]:
def plot_centers(centers, color='C3'):
    """Plot cluster centers.
    centers: array with x and y columns
    color: string color specification
    xs, ys = np.transpose(centers)
    options = dict(color=color, ls='none', mfc='none')
    plt.plot(xs, ys, marker='o', ms=15, **options)
    plt.plot(xs, ys, marker='x', ms=10, **options)

Now let's pull it all together.

In [33]:
scatterplot(data, var1, var2, 'labels')

This figure shows the data, color-coded by assigned label, and the centers of the clusters.

It looks like k-means does a reasonable job of rediscovering the species, but with some confusion between Adelie (lower left) and Chinstrap (top center).

As a reminder, here are the right answers:

In [34]:
scatterplot(data, var1, var2, 'Species2')

Note that the color coding for the clusters is not consistent because the centers we get from k-means are in a random order.

Exercise: Here's the code from this section all in one place. Modify it to use any two features and see what the results look like.

In [35]:
features2 = [var1, var2]

data2 = df.dropna(subset=features2).copy()

M2 = data2[features2].to_numpy()

kmeans2 = KMeans(n_clusters=3).fit(M2)

data2['labels'] = kmeans2.labels_

scatterplot(data2, var1, var2, 'labels')

Implementing k-means

Now let's see how the algorithm works. At a high level, there are three steps:

  1. Choose $k$ random points in the dataset as initial centers.
  2. Assign each point in the dataset to the closest center.
  3. Compute new centers by calculating the "center of mass" in each cluster.

Then you repeat steps 2 and 3 until the centers stop moving.

To select random points from the dataset, I'll use np.random.choice to select three indices.

In [36]:
index = np.random.choice(len(M), size=3)

And then use the indices to select rows from the dataset.

In [37]:
centers = M[index]

I'll wrap that in a function:

In [38]:
def choose_random_start(M, k):
    """Choose k random elements of M.
    M: NumPy array with rows of coordinates
    k: number of centers to choose
    returns: NumPy array
    index = np.random.choice(len(M), size=k)
    centers = M[index]
    return centers

Here's how we use it.

In [39]:
centers = choose_random_start(M, 3)

And here's what the centers look like on the scatterplot.

In [40]:
data['labels'] = 0
scatterplot(data, var1, var2, 'labels')

The next step is to assign each point to the closest center. So we need to compute the distance between each point and each of the centers.

Compute distances

To demonstrate the process, I'll pick just one of the centers.

In [41]:
center_x, center_y = centers[0]
center_x, center_y

Now it will be convenient to have the x and y coordinates in separate arrays. I can do that with np.transpose, which turns the columns into rows; then I can assign the rows to x and y.

In [42]:
x, y = np.transpose(M)

Along the x-axis, the distance from each point to this center is x-center_x.

Along the y-axis, the distance is y-center_y.

The distance from each point to the center is the hypotenuse of the triangle, which I can compute with np.hypot:

In [43]:
distances = np.hypot(x-center_x, y-center_y)

The result is an array that contains the distance from each point to the chosen center.

To see if we got it right, I'll plot the center and the points, with the size of the points proportional to distance.

In [44]:
plt.plot(center_x, center_y, 'rx', markersize=10)
plt.scatter(x, y, s=distances)

decorate(xlabel=var1, ylabel=var2)

At least visually, it seems like the size of the points is proportional to their distance from the center.

So let's put those steps into a function:

In [45]:
def compute_distances(M, center):
    """Compute distances to the given center.
    M: NumPy array of coordinates
    center: x, y coordinates of the center
    returns: NumPy array of float distances
    x, y = np.transpose(M)
    center_x, center_y = center
    distances = np.hypot(x-center_x, y-center_y)
    return distances

We can use the function to make a list of distance arrays, one for each center.

In [46]:
distance_arrays = [compute_distances(M, center)
                   for center in centers]

Labeling the points

The next step is to label each point with the index of the center it is closest to.

distance_arrays is a list of arrays, but we can convert it to a 2-D array like this:

In [47]:
A = np.array(distance_arrays)

A has one row for each center and one column for each point.

Now we can use np.argmin to find the shortest distance in each column and return its index.

In [48]:
data['labels'] = np.argmin(A, axis=0)

The result is an array of indices in the range 0..2, which we assign to a column in data.

Let's put these steps in a function.

In [49]:
def compute_labels(M, centers):
    """Label each point with the index of the closest center.
    M: NumPy array of coordinates
    centers: array of coordinates for the centers
    returns: array of labels, 0..k-1
    distance_arrays = [compute_distances(M, center)
                       for center in centers]
    A = np.array(distance_arrays)
    labels = np.argmin(A, axis=0)
    return labels

We can call it like this:

In [50]:
data['labels'] = compute_labels(M, centers)

And here are the results.

In [51]:
scatterplot(data, var1, var2, 'labels')

If we get lucky, we might start with one point near the center of each cluster. But even if we are unlucky, we can improve the results by recentering.

Find new centers

The last step is to use the labels from the previous step to compute the center of each cluster.

I'll start by using groupby to group the points by label.

In [52]:
grouped = data.groupby('labels')
for name, group in grouped:
    print(name, len(group))

We can use the GroupBy object to select the columns we're using and compute their means.

In [53]:

The result is a DataFrame that contains the central coordinates of each cluster.

I'll put these steps in a function.

In [54]:
def compute_new_centers(data, features):
    """Compute the center of each cluster.
    data: DataFrame
    features: list of string column names
    means = data.groupby('labels')[features].mean()
    return means.to_numpy()

The return value is a NumPy array that contains the new centers.

In [55]:
new_centers = compute_new_centers(data, features)

Here's what it looks like with the old centers in gray and the new centers in red.

In [56]:
scatterplot(data, var1, var2, 'labels')
plot_centers(centers, color='gray')
plot_centers(new_centers, color='C3')

The k-means algorithm

Now here's the whole algorithm in one function.

In [57]:
def k_means(data, features, k):
    """Cluster by k means.
    data: DataFrame
    features: list of string column names
    k: number of clusters
    returns: array of centers
    M = data[features].to_numpy()
    centers = choose_random_start(M, k)

    for i in range(15):
        data['labels'] = compute_labels(M, centers)
        centers = compute_new_centers(data, features)

    return centers

And here's what the results look like after 15 iterations.

In [58]:
centers = k_means(data, features, 3)
scatterplot(data, var1, var2, 'labels')
plot_centers(centers, color='C3')

The results are (as far as I can see) identical to what we got from the scikit-learn implementation.

In [59]:
kmeans = KMeans(n_clusters=3).fit(M)
data['labels'] = kmeans.labels_

scatterplot(data, var1, var2, 'labels')


Here's an animation that shows the algorithm in action.

In [60]:
from time import sleep
from IPython.display import clear_output
interval = 1

centers = choose_random_start(M, k=3)

for i in range(10):
    # label and scatter plot
    data['labels'] = compute_labels(M, centers)
    scatterplot(data, var1, var2, 'labels')
    plot_centers(centers, color='gray')
    # compute new centers and plot them
    new_centers = compute_new_centers(data, features)        
    centers = new_centers
    # show the plot, wait, and clear

Exercise: Run the previous cell a few times. Do you always get the same clusters?

Number of clusters

All of this is based on the assumption that you know how many clusters you are looking for, which is true for some applications, but not always.

Let's see what goes wrong if you ask for too many clusters, or too few.

Exercise: Run the following code with different values of n_clusters and see what the results look like.

In [61]:
kmeans = KMeans(n_clusters=3).fit(M)
data['labels'] = kmeans.labels_

scatterplot(data, var1, var2, 'labels')


One of the problems with the results we have seen so far is that the lines between the clusters are mostly vertical.

That's because the range of values is wider for flipper length than culmen length, about 60 mm compared to 28 mm.

In [62]:
M.max(axis=0) - M.min(axis=0)

In [63]:

When we compute the distance from each point to each center, the distances in the $x$ direction tend to dominate.

This is a common problem with algorithms that are based on distance in multidimensional space.

It is such a common problem that there is a common solution: feature scaling.

The goal of feature scaling is to transform the features so the distances along each axis are comparable.

One version of feature scaling is "standardization", which consists of

  1. Subtracting the mean from each feature, and
  2. Dividing through by the standard deviation.

Here's how we can do it with the features in M:

In [64]:
means = M.mean(axis=0)

In [65]:
stds = M.std(axis=0)

In [66]:
M_std = (M - means) / stds

Let's see what happens if we run the algorithm again with standardized features.

Notice that I have to transform the centers back before plotting them.

In [67]:
kmeans = KMeans(n_clusters=3).fit(M_std)
data['labels'] = kmeans.labels_

scatterplot(data, var1, var2, 'labels')

centers = kmeans.cluster_centers_ * stds + means

That looks a lot better! Again, here are the actual species for comparison.

In [68]:
scatterplot(data, var1, var2, 'Species2')

scikit-learn provides StandardScaler, which does the same thing.

In [69]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(M)
M_std = scaler.transform(M)

And scaler provides inverse_transform, which we can use to transform the centers back.

In [70]:
kmeans = KMeans(n_clusters=3).fit(M_std)
data['labels'] = kmeans.labels_

scatterplot(data, var1, var2, 'labels')

centers = scaler.inverse_transform(kmeans.cluster_centers_)


The k-means algorithm does unsupervised clustering, which means that we don't tell it where the clusters are; we just provide the data and ask it to find a given number of clusters.

In this notebook, we asked it to find clusters in a group of penguins based on two features, flipper length and culmen length. The clusters it finds reflect the species in the dataset, especially if we standardize the data.

In this example we used only two features, because that makes it easy to visualize the results. But k-means extends easily to any number of dimensions (see the exercise below).

So, what is this good for?

Well, Wikipedia provides this list of applications. Applying clustering analysis to these applications, I see a few general ideas:

  • From an engineering point of view, clustering can be used to automate some kinds of analysis people do, which might be faster, more accurate, or less expensive. And it can work with large datasets and high numbers of dimensions that people can't handle.

  • From a scientific point of view, clustering provides a way to test whether the patterns we see are in the data or in our minds.

This second point is related to old philosophical questions about the nature of categories. Putting things into categories seems to be a natural part of how humans think, but we have to wonder whether the categories we find truly "carve nature at its joints", as Plato put it.

If a clustering algorithm finds the same "joints" we do, we might have more confidence they are not entirely in our minds.

Exercise: Use the scikit-learn implementation of k-means to find clusters using all four features (flipper length, culmen length and depth, body mass). How do the results compare to what we got with just two features?

In [71]:
# Solution goes here

In [72]:
# Solution goes here

In [73]:
# Solution goes here

In [74]:
# Solution goes here

In [75]:
# Solution goes here

In [76]:
# Solution goes here

In [ ]: