Color Quantization using k-Means

This code is provided as supplementary material for the lecture Machine Learning and Optimization in Communications (MLOC).

This code illustrates

  • The use of the k-means algorithm to quantize 24-bit RGB colors of a bitmap image

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from sklearn.cluster import KMeans
%matplotlib inline

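Load an image. It needs to be an RGB image (24-bit color information), preferably in PNG format, because PNG is lossless and leaves the color information unaltered. Note that `mpimg.imread` returns the pixels of a PNG file as floats in [0, 1], and some PNG files carry a fourth (alpha) channel.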
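Since k-means clusters points in a 3-D color space, the pixel array should have exactly three channels and a consistent value range. The sketch below uses a hypothetical helper `to_rgb_float` (not part of the original notebook) that assumes either 8-bit integer data (0..255) or float data already in [0, 1], drops an alpha channel if present, and returns an (H, W, 3) float array.

```python
import numpy as np

def to_rgb_float(img):
    """Hypothetical helper: return an (H, W, 3) float array with values in [0, 1]."""
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[2] not in (3, 4):
        raise ValueError("expected an RGB or RGBA image of shape (H, W, 3) or (H, W, 4)")
    rgb = img[:, :, :3]  # keep R, G, B and drop the alpha channel if there is one
    if rgb.dtype == np.uint8:
        # assumption: 8-bit integer images store 0..255 per channel
        return rgb.astype(np.float64) / 255.0
    # assumption: float images (e.g. PNGs read by mpimg.imread) are already in [0, 1]
    return rgb.astype(np.float64)

# Example with a synthetic 2x2 RGBA image of 8-bit values
rgba = np.zeros((2, 2, 4), dtype=np.uint8)
rgba[0, 0] = [255, 0, 0, 255]
rgb = to_rgb_float(rgba)
print(rgb.shape)   # (2, 2, 3)
print(rgb[0, 0])   # [1. 0. 0.]
```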


In [2]:
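image = mpimg.imread('Mandrill.png')
# keep only the R, G, B channels (drops the alpha channel if the PNG has one)
image = image[:, :, :3]
plt.figure(1,figsize=(10,10))
plt.imshow(image)
plt.title('Original image (24-bit color information)')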
plt.show()


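Quantize the color information of the image. For this, we use the k-means algorithm to cluster the pixels, viewed as points in the 3-D RGB color space, into num_colors clusters. We then replace the color of each pixel by the center of the cluster it has been assigned to, so that the resulting image contains at most num_colors distinct colors.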
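To make the clustering step concrete, here is a minimal NumPy sketch of Lloyd's algorithm, the classic iteration behind k-means: assign every point to its nearest center, then move every center to the mean of its assigned points. It is an illustration only, not what scikit-learn's `KMeans` does internally in detail: `KMeans` uses k-means++ initialization and several restarts (`n_init`) by default, whereas this sketch uses plain random initialization and one run. The function name `kmeans_lloyd` is hypothetical, and the (N, k, 3) distance array makes it suitable only for small inputs.

```python
import numpy as np

def kmeans_lloyd(points, k, n_iter=20, seed=0):
    """Hypothetical sketch of Lloyd's k-means on an (N, 3) array; returns (centers, labels)."""
    points = np.asarray(points, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialize the centers with k distinct, randomly chosen points
    centers = points[rng.choice(len(points), size=k, replace=False)].copy()

    def assign(centers):
        # Squared Euclidean distance from every point to every center, shape (N, k)
        dists = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        return dists.argmin(axis=1)

    for _ in range(n_iter):
        labels = assign(centers)                 # assignment step
        for j in range(k):                       # update step
            members = points[labels == j]
            if len(members) > 0:                 # keep the old center if a cluster is empty
                centers[j] = members.mean(axis=0)
    return centers, assign(centers)              # labels consistent with the final centers

# Two obvious color groups: near-black and near-white
pts = np.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [1.0, 1.0, 1.0], [0.9, 1.0, 1.0]])
centers, labels = kmeans_lloyd(pts, k=2)
quantized = centers[labels]  # same lookup as in quantize_colors
```

The last line mirrors the `cluster_centers_[labels_]` lookup in `quantize_colors`: each point is replaced by the center of its cluster.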


In [3]:
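def quantize_colors(image, num_colors):
    # flatten the (height, width, 3) image into an array of RGB points, one row per pixel
    pixels = np.reshape(image, (image.shape[0]*image.shape[1], 3))

    # cluster the pixel colors into num_colors clusters (best of 8 k-means runs)
    kmeans = KMeans(n_clusters=num_colors, random_state=0, n_init=8).fit(pixels)

    # replace each pixel by the center of its cluster and restore the image shape
    retval = kmeans.cluster_centers_[kmeans.labels_]
    return np.reshape(retval, (image.shape[0], image.shape[1], 3))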
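The image returned by `quantize_colors` still stores three floats per pixel; the actual saving of color quantization comes from storing a small palette of colors plus one palette index per pixel. The following sketch, using hypothetical helpers `to_palette` and `from_palette` that are not part of the notebook, recovers that indexed representation from any quantized (H, W, 3) array with `np.unique` and reconstructs the image by table lookup.

```python
import numpy as np

def to_palette(image_q):
    """Hypothetical helper: split an (H, W, 3) quantized image into a palette and an index map."""
    h, w, c = image_q.shape
    pixels = image_q.reshape(-1, c)
    # Unique rows are the palette colors; 'inverse' gives each pixel's row in the palette
    palette, inverse = np.unique(pixels, axis=0, return_inverse=True)
    index_map = inverse.reshape(h, w)  # reshape also covers NumPy versions returning (N, 1)
    return palette, index_map

def from_palette(palette, index_map):
    """Look up the color of every pixel in the palette (fancy indexing)."""
    return palette[index_map]

# Synthetic 2x2 "quantized" image with two colors: black and red
img = np.zeros((2, 2, 3))
img[0, 1] = [1.0, 0.0, 0.0]
img[1, 1] = [1.0, 0.0, 0.0]
palette, index_map = to_palette(img)
print(palette.shape)        # (2, 3)
print(index_map.tolist())   # [[0, 1], [0, 1]]
```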

Visualize the quantized images using 4, 16 and 50 colors (any other number of colors would be possible as well).
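A quick calculation shows what these color counts mean for storage. With an indexed representation, each pixel needs only ceil(log2(k)) bits for its palette index instead of 24 bits, plus a palette of k 24-bit colors. The sketch below uses a hypothetical function `storage_bytes` and assumes a 512×512 image purely for illustration; it ignores any file-format overhead or further compression.

```python
import math

def storage_bytes(height, width, num_colors):
    """Hypothetical estimate of raw indexed-color storage: (bits per index, total bytes)."""
    bits_per_index = max(1, math.ceil(math.log2(num_colors)))
    index_bytes = math.ceil(height * width * bits_per_index / 8)
    palette_bytes = num_colors * 3  # one 24-bit RGB entry per palette color
    return bits_per_index, index_bytes + palette_bytes

height, width = 512, 512  # assumed image size, for illustration only
print("24-bit original:", height * width * 3, "bytes")   # 786432 bytes
for k in (4, 16, 50):
    bits, size = storage_bytes(height, width, k)
    print(f"{k} colors: {bits} bits/pixel, {size} bytes")
# 4 colors: 2 bits/pixel, 65548 bytes
# 16 colors: 4 bits/pixel, 131120 bytes
# 50 colors: 6 bits/pixel, 196758 bytes
```

Note that 50 colors already needs 6 bits per index, the same as 64 colors would, so k = 64 uses the index bits fully.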


In [5]:
image_q = quantize_colors(image,4)
plt.figure(1,figsize=(20,10))
plt.subplot(121)
plt.imshow(image)
plt.title('Original')
plt.subplot(122)
plt.imshow(image_q)
plt.title('Image with 4 colors')
plt.show()



In [6]:
image_q = quantize_colors(image,16)
plt.figure(1,figsize=(20,20))
plt.subplot(121)
plt.imshow(image)
plt.title('Original')
plt.subplot(122)
plt.imshow(image_q)
plt.title('Image with 16 colors')
plt.show()



In [6]:
image_q = quantize_colors(image,50)
plt.figure(1,figsize=(20,20))
plt.subplot(121)
plt.imshow(image)
plt.title('Original')
plt.subplot(122)
plt.imshow(image_q)
plt.title('Image with 50 colors')
plt.show()