K-means for image compression

This notebook walks through one possible use of k-means clustering: image segmentation and compression.


In [17]:
%matplotlib inline
import numpy as np
from matplotlib.pyplot import imshow, title
from sklearn.datasets import load_sample_image
from sklearn.cluster import KMeans

In [18]:
img = load_sample_image("china.jpg")

In [19]:
img.shape, img.dtype


Out[19]:
((427, 640, 3), dtype('uint8'))

In [20]:
print("Storing the raw pixel data (1 byte per channel) would take {} KiB".format(img.nbytes / 1024))


Storing the raw pixel data (1 byte per channel) would take 800.625 KiB

In [21]:
# Convert from the default 8-bit integer encoding to floats. Dividing by 255
# matters because imshow expects float RGB data to lie in the range [0, 1].
china = np.array(img, dtype=np.float64) / 255

imshow(china)



In [22]:
n_colors = 128

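# Flatten the image into a 2D array with one row per pixel and one column per
# RGB channel. The shape is (rows, columns, channels), so `w` here actually
# holds the image height (427) and `h` the width (640).
w, h, d = original_shape = tuple(china.shape)
assert d == 3
image_array = np.reshape(china, (w * h, d))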

image_array.shape


Out[22]:
(273280, 3)
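Flattening is lossless: np.reshape only changes how the same values are indexed, in NumPy's default row-major order, which is also the order the labels are later mapped back to pixels. A minimal sketch on a made-up 2 x 3 image (the array `tiny` and the names `rows`, `cols`, `channels` are hypothetical, used only for illustration) shows where each pixel lands and that the round trip is exact:

```python
import numpy as np

# Hypothetical 2 x 3 RGB "image" with arbitrary values in [0, 1]
tiny = np.arange(18, dtype=np.float64).reshape(2, 3, 3) / 17.0

rows, cols, channels = tiny.shape
# One row per pixel, one column per channel (row-major order)
pixels = np.reshape(tiny, (rows * cols, channels))

# Pixel (r, c) of the image ends up in row r * cols + c of the flat array
r, c = 1, 2
assert np.array_equal(pixels[r * cols + c], tiny[r, c])

# Reshaping back restores the image exactly
restored = np.reshape(pixels, (rows, cols, channels))
```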

In [23]:
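# n_init=10 was the default in older scikit-learn releases (see the output below);
# since 1.4 the default is 'auto', so it is set explicitly here
k_means = KMeans(init='k-means++', n_clusters=n_colors, n_init=10)

# The model can be fitted on a random subset of the pixels or on all of them;
# as you might expect, fitting the whole image takes much longer
#k_means.fit(image_array)

# Draw 1,000 distinct pixels at random without shuffling a copy of the full array
rng = np.random.default_rng(0)
image_array_sample = image_array[rng.choice(len(image_array), size=1000, replace=False)]
k_means.fit(image_array_sample)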


Out[23]:
KMeans(copy_x=True, init='k-means++', max_iter=300, n_clusters=128, n_init=10,
    n_jobs=1, precompute_distances=True, random_state=None, tol=0.0001,
    verbose=0)
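For intuition about what `k_means.fit` optimizes: k-means looks for centers that minimize the sum of squared distances from each point to its nearest center (the "inertia"). The sketch below runs plain Lloyd iterations in NumPy, alternating an assignment step and an update step. `lloyd_kmeans` is a hypothetical helper written for illustration; it uses uniform random initialization rather than k-means++, and it is not scikit-learn's implementation, which among other things repeats the whole fit `n_init` times and keeps the best result.

```python
import numpy as np

def lloyd_kmeans(X, k, n_iter=100, seed=0):
    """Illustrative Lloyd's algorithm: returns (centers, labels, inertia)."""
    rng = np.random.default_rng(seed)
    # Initialize with k distinct data points picked uniformly (not k-means++)
    centers = X[rng.choice(len(X), size=k, replace=False)].astype(np.float64)
    for _ in range(n_iter):
        # Assignment step: squared distance from each point to each center
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Update step: move each center to the mean of the points assigned to it
        new_centers = centers.copy()
        for j in range(k):
            members = X[labels == j]
            if len(members):  # an empty cluster keeps its previous center
                new_centers[j] = members.mean(axis=0)
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    # Final assignment and the objective k-means minimizes (sum of squared distances)
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    inertia = float(d2[np.arange(len(X)), labels].sum())
    return centers, labels, inertia

# Usage on made-up data: two tight blobs of dark and light grey "pixels"
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0.2, 0.02, size=(50, 3)),
               rng.normal(0.8, 0.02, size=(50, 3))])
centers, labels, inertia = lloyd_kmeans(X, k=2)
```

The broadcasted distance matrix has shape (n_points, k, 3), so this version is only practical for small inputs such as the 1,000-pixel sample, not the full image.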

In [24]:
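palette = k_means.cluster_centers_

# Show the palette as an (n_colors // 8) x 8 grid of swatches; integer division
# keeps the shape an int (n_colors / 8 is a float in Python 3 and reshape fails)
imshow(palette.reshape((n_colors // 8, 8, 3)), interpolation="none")
title("Color palette");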



In [25]:
# Assign every pixel of the full image to its nearest palette color
print("Predicting color indices on the full image using k-means")
labels = k_means.predict(image_array)
# With at most 256 colors, every label fits in a single byte (uint8)
assert n_colors <= 256
labels = labels.astype(np.uint8)


Predicting color indices on the full image using k-means
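Predicting a label means picking the nearest cluster center in Euclidean distance. A minimal sketch that checks this by hand against `predict`, using made-up random pixels and an 8-color model (`pixels`, `km` and `manual_labels` are hypothetical stand-ins, not the notebook's variables):

```python
import numpy as np
from sklearn.cluster import KMeans

# Made-up stand-ins for the notebook's data: 500 random RGB pixels, 8 colors
rng = np.random.default_rng(0)
pixels = rng.random((500, 3))
km = KMeans(n_clusters=8, n_init=10, random_state=0).fit(pixels)

# Nearest-center rule by hand: squared Euclidean distance to every center,
# then the index of the closest one
d2 = ((pixels[:, None, :] - km.cluster_centers_[None, :, :]) ** 2).sum(axis=2)
manual_labels = d2.argmin(axis=1)

# Should agree with predict, barring exact distance ties
same = np.array_equal(manual_labels, km.predict(pixels))
```

The broadcasted version builds an (n_pixels, n_colors, 3) array, so on the full image it would need far more memory than calling `predict`.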

In [26]:
# Now our compressed image is represented by two pieces of information

# The color labels:
labels.shape, labels.dtype


Out[26]:
((273280,), dtype('uint8'))

In [27]:
# And the cluster centers, i.e. the RGB values of the palette colors
k_means.cluster_centers_.shape, k_means.cluster_centers_.dtype


Out[27]:
((128, 3), dtype('float64'))

In [28]:
print("K-Means storage of the data would take {} KiB".format((k_means.cluster_centers_.nbytes + labels.nbytes) / 1024))


K-Means storage of the data would take 269.875 KiB
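The two storage figures follow from simple byte counting. Writing $N = 427 \times 640 = 273\,280$ for the number of pixels and $K$ for `n_colors`:

$$
\begin{aligned}
\text{raw} &= 3N = 819\,840\ \text{B} = 800.625\ \text{KiB} \\
\text{k-means} &= \underbrace{8 \cdot 3K}_{\text{float64 palette}} + \underbrace{N}_{\text{uint8 labels}} = 3\,072 + 273\,280 = 276\,352\ \text{B} = 269.875\ \text{KiB}
\end{aligned}
$$

That is a ratio of $819\,840 / 276\,352 \approx 2.97$, and almost all of the compressed size is the one-byte-per-pixel label array. Since 128 colors need only $\log_2 128 = 7$ bits per label, packing the labels more tightly could shrink that term further.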

In [29]:
def recreate_image(codebook, labels, w, h):
    """Recreate the (compressed) image from the codebook and labels."""
    # Fancy indexing looks up each pixel's palette color in one step; reshaping
    # restores the (rows, columns, channels) layout in the same row-major order
    # that a nested loop over rows, then columns, would use
    return codebook[labels].reshape(w, h, codebook.shape[1])

In [30]:
title('Quantized image ({} colors, K-Means)'.format(n_colors))
imshow(recreate_image(k_means.cluster_centers_, labels, w, h));


This use of k-means is not a particularly sophisticated image compression scheme: a better algorithm would almost certainly take the spatial proximity of the pixels into account. The compression is also lossy, as you can easily see by setting `n_colors` to a small value such as 16.
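To put a number on the loss, one option is the mean squared error between an image and its quantized version. In the sketch below, `quantization_mse` and the 4-color test image are hypothetical. On an image with exactly K distinct colors, K clusters can reproduce it exactly, while fewer clusters cannot, which is the lossiness in its simplest form. On the photo itself you could compare, for example, `quantization_mse(china, 16)` with `quantization_mse(china, 128)`, keeping in mind that this helper fits on every pixel, which is slow.

```python
import numpy as np
from sklearn.cluster import KMeans

def quantization_mse(image, n_colors, seed=0):
    """Mean squared error between a float RGB image and its k-means quantization."""
    pixels = image.reshape(-1, image.shape[-1])
    km = KMeans(n_clusters=n_colors, n_init=10, random_state=seed).fit(pixels)
    # Replace every pixel by its cluster center and compare with the original
    quantized = km.cluster_centers_[km.labels_]
    return float(np.mean((pixels - quantized) ** 2))

# Hypothetical 16 x 16 test image built from exactly 4 distinct colors
rng = np.random.default_rng(0)
colors = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
image = colors[rng.integers(0, 4, size=(16, 16))]

errors = {k: quantization_mse(image, k) for k in (2, 4)}
```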