K-means clustering algorithm (unsupervised learning) for image compression

A big thank you to Andrew Ng's Coursera Machine Learning course (this is a Python solution to one of his class assignments) and Isabell Graf (who made me aware of this neat problem and helped develop the code below). The hamster's name is Mario :)

Our task

We will use K-means clustering to compress the image above to an image with fewer colours.

More precisely, each pixel is treated as a data point (regardless of its location in the image). K-means clustering is applied to the RGB values to find K representative colours, one per cluster, each being the mean colour of the pixels in that cluster. Then each pixel is replaced by the mean colour of its cluster, and the image is reassembled from the transformed pixel values.
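To make the replacement step concrete, here is a minimal sketch on a made-up 2 x 2 image. The pixel values and the cluster labels in `toy_labels` are assumptions chosen for illustration (no clustering is run here); the point is only to show how the mean colour of each cluster replaces its pixels.

```python
import numpy as np

# Hypothetical 2 x 2 RGB image: two reddish pixels, two bluish pixels
toy_image = np.array([[[250, 0, 0], [240, 10, 0]],
                      [[0, 0, 250], [10, 0, 240]]], dtype=np.uint8)

# Flatten to a list of pixels; the location in the image is ignored
toy_pixels = toy_image.reshape(-1, 3).astype(float)

# Assumed cluster labels (in the real task these come from K-means)
toy_labels = np.array([0, 0, 1, 1])

# Mean colour of each cluster
toy_centroids = np.array([toy_pixels[toy_labels == k].mean(axis=0)
                          for k in range(2)])

# Replace each pixel by the mean colour of its cluster and reassemble
toy_compressed = toy_centroids[toy_labels].reshape(toy_image.shape)
```

After this, `toy_compressed` contains only two distinct colours, one per cluster.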

Data ingestion


In [1]:
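# Note: misc.imread and misc.imsave only exist in older SciPy releases
# (they were removed in SciPy 1.2); imageio offers similar functions.
from scipy import misc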
import numpy as np

In [2]:
image = misc.imread('Mario_mit_Apfel.jpg')
print('Image size', image.shape, 'height x width x RGB')


('Image size', (900, 1200, 3), 'height x width x RGB')

Transform the image into a simple array of pixels, one row of RGB values per pixel.


In [3]:
num_pixels = image.shape[0] * image.shape[1]
pixels = image.reshape(num_pixels, 3)
print(pixels)


[[ 95  87  64]
 [ 88  82  58]
 [107 101  77]
 ..., 
 [172 175 164]
 [174 177 166]
 [173 176 165]]
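As a quick check of what this reshape does, the sketch below uses a small made-up array (`demo_image` is hypothetical, not the hamster image). It illustrates that the pixel at row r and column c lands at index r * width + c, and that reshaping back restores the original layout, which is what makes reassembling the image possible later.

```python
import numpy as np

# Hypothetical tiny image: 2 rows, 3 columns, 3 colour channels
demo_image = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)

# Flatten to (number of pixels, 3); -1 lets NumPy infer the pixel count
demo_pixels = demo_image.reshape(-1, 3)

# Pixel (row 1, column 2) ends up at flat index 1 * 3 + 2 = 5
same_pixel = np.array_equal(demo_pixels[1 * 3 + 2], demo_image[1, 2])

# Reshaping back recovers the original image exactly
restored = demo_pixels.reshape(demo_image.shape)
```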

Aside: plot function to show points and cluster centers


In [4]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

In [5]:
def plot_clusters(points, centers, new_centers, groups, num_to_plot=5000):
    # only plot num_to_plot random points to speed up rendering
    indices = np.random.choice(points.shape[0], num_to_plot, replace=False)
    fig = plt.figure(figsize = (10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(points[indices, 0], points[indices, 1], points[indices, 2], c=groups[indices]/255.0, depthshade=False, alpha=0.05)
    ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], marker='v', s=99, c='m', depthshade=False)
    ax.scatter(new_centers[:, 0], new_centers[:, 1], new_centers[:, 2], marker='^', s=99, c='r', depthshade=False)
    plt.show()

K-means algorithm

Step 0. Initialization


In [6]:
K = 8  # number of centroids (colours in the final output image)
max_iter = 25  # maximal number of iterations
epsilon = 0.5  # stop when the largest squared centroid shift drops below this

# Initialize centroids: K random data points
# (cast to float: the pixels are uint8, and uint8 subtraction would wrap around)
centroids = pixels[np.random.choice(pixels.shape[0], K, replace=False), :].astype(float)
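A word of caution about data types when computing distances between pixels and centroids: image arrays are usually stored as `uint8`, and NumPy's `uint8` arithmetic wraps around modulo 256 instead of going negative. The short sketch below (with made-up values `a` and `b`) shows the effect and why working in floating point is the safer choice.

```python
import numpy as np

a = np.array([10], dtype=np.uint8)
b = np.array([20], dtype=np.uint8)

# uint8 cannot represent -10, so the result wraps around to 246
wrapped = a - b

# Converting to float first gives the expected difference
correct = a.astype(float) - b.astype(float)
```

If either operand is already a float array, NumPy promotes the result to float and the problem does not occur.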

Step 1. Alternate assignment and update


In [7]:
for iteration in range(max_iter):
    # Assignment step: determine the closest centroid for each point
    diff = pixels.reshape(num_pixels, 1, 3) - centroids
    distance = (diff ** 2).sum(axis=2)
    groups = distance.argmin(axis=1)   

    # Update step: calculate new centroids as mean of points in cluster
    new_centroids = np.zeros((K, 3))
    for k in range(K):
        new_centroids[k, :] = np.mean(pixels[groups == k], axis=0)
        
    
    # Report and plot
    change = ((centroids - new_centroids)**2).sum(axis=1).max()
    print('Iteration', iteration + 1, 'Change in centroid locations', change)
    plot_clusters(pixels, centroids, new_centroids, groups)
    
    # Stop the iteration if the centroids don't move anymore    
    if change < epsilon:
        break
    else:
        centroids = new_centroids


('Iteration', 1, 'Change in centroid locations', 2951.8092056273613)
('Iteration', 2, 'Change in centroid locations', 1408.5417173561013)
('Iteration', 3, 'Change in centroid locations', 285.00260915071357)
('Iteration', 4, 'Change in centroid locations', 117.11515490636154)
('Iteration', 5, 'Change in centroid locations', 40.705201890989358)
('Iteration', 6, 'Change in centroid locations', 20.593288294595553)
('Iteration', 7, 'Change in centroid locations', 31.71811667694444)
('Iteration', 8, 'Change in centroid locations', 30.042966742610336)
('Iteration', 9, 'Change in centroid locations', 18.489215811466142)
('Iteration', 10, 'Change in centroid locations', 10.559866042413205)
('Iteration', 11, 'Change in centroid locations', 7.0338964050034676)
('Iteration', 12, 'Change in centroid locations', 4.6554845496753723)
('Iteration', 13, 'Change in centroid locations', 3.3182387037862942)
('Iteration', 14, 'Change in centroid locations', 2.0306203633910269)
('Iteration', 15, 'Change in centroid locations', 1.3956227813263056)
('Iteration', 16, 'Change in centroid locations', 1.0127913665859323)
('Iteration', 17, 'Change in centroid locations', 0.71761818038487568)
('Iteration', 18, 'Change in centroid locations', 0.56570403418470794)
('Iteration', 19, 'Change in centroid locations', 0.32548793320641189)

The magenta (downward-pointing) and red (upward-pointing) triangles indicate the locations of the previous and new cluster centroids, respectively.
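For reference, the assignment and update steps can be packed into one function without the plotting. This is a sketch rather than the notebook's exact code: the name `kmeans_sketch`, the `seed` parameter and the toy data are assumptions, and it differs in two small ways (an empty cluster keeps its old centroid instead of producing NaN, and a final assignment is made against the final centroids).

```python
import numpy as np

def kmeans_sketch(points, K, max_iter=25, epsilon=0.5, seed=0):
    """Plain K-means on an (N, D) array; returns (centroids, groups)."""
    rng = np.random.RandomState(seed)
    points = points.astype(float)  # avoid integer wrap-around
    centroids = points[rng.choice(points.shape[0], K, replace=False)]

    for _ in range(max_iter):
        # Assignment step: squared distance of every point to every
        # centroid via broadcasting, shape (N, K)
        distance = ((points[:, np.newaxis, :] - centroids) ** 2).sum(axis=2)
        groups = distance.argmin(axis=1)

        # Update step: mean of the members; empty clusters stay put
        new_centroids = centroids.copy()
        for k in range(K):
            members = points[groups == k]
            if len(members) > 0:
                new_centroids[k] = members.mean(axis=0)

        # Largest squared shift of any centroid
        change = ((centroids - new_centroids) ** 2).sum(axis=1).max()
        centroids = new_centroids
        if change < epsilon:
            break

    # Final assignment so that groups matches the returned centroids
    distance = ((points[:, np.newaxis, :] - centroids) ** 2).sum(axis=2)
    return centroids, distance.argmin(axis=1)

# Toy data: two tight pairs of points, far apart from each other
toy_points = np.array([[0, 0, 0], [2, 0, 0], [100, 0, 0], [102, 0, 0]])
toy_centroids, toy_groups = kmeans_sketch(toy_points, K=2)
```

On this toy data the two pairs end up in separate clusters; which pair gets label 0 depends on the random initialisation.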

Put compressed image back together


In [8]:
# Replace each original pixel with the RGB values of the centroid of its group.
pixels_compressed = centroids[groups, :]

# Round and cast to uint8, the usual dtype for 8-bit RGB image data
image_compressed = np.round(pixels_compressed).astype(np.uint8).reshape(image.shape)
misc.imsave('Mario_compressed.png', image_compressed)
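To get a feeling for how much a palette of K colours could save, here is a back-of-the-envelope calculation. It assumes an idealised raw encoding (one palette index per pixel plus the palette itself) and says nothing about the actual size of the saved PNG file, which applies its own compression. The image shape and K are taken from the cells above; the variable names are illustrative.

```python
height, width = 900, 1200  # image shape reported above
K = 8                      # number of colours in the palette

# Original: 3 channels of 8 bits for every pixel
original_bits = height * width * 3 * 8

# Compressed: enough bits per pixel to index K colours,
# plus the palette itself (K colours of 24 bits each)
bits_per_index = (K - 1).bit_length()  # 3 bits for K = 8
compressed_bits = height * width * bits_per_index + K * 3 * 8

ratio = original_bits / float(compressed_bits)
print('Idealised compression ratio', ratio)
```

With K = 8 each pixel needs 3 bits instead of 24, so the ratio comes out just under 8.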
