# K-means clustering algorithm (unsupervised learning) for image compression

A big thank you to Andrew Ng's Coursera Machine Learning course (this is a Python solution to one of his class assignments) and to Isabell Graf, who made me aware of this neat problem and helped develop the code below. The hamster's name is Mario :)

We will use K-means clustering to compress the image above (a photo of Mario the hamster) into an image with fewer colours.

More precisely, each pixel is treated as a data point in RGB space (regardless of its location in the image). K-means clustering on the RGB values yields a representative average colour for each of the K clusters. Each pixel is then replaced by the average colour of its cluster, and the image is reassembled from the transformed pixel values.
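To see why this compresses the image, consider the storage cost: each original pixel needs 24 bits (3 bytes of RGB), while a compressed pixel only needs an index into a small K-entry colour palette. A quick back-of-the-envelope calculation for the 900 x 1200 image and K = 8 used below (a rough estimate; it ignores container overhead and any entropy coding):

```python
import numpy as np

width, height = 1200, 900
num_pixels = width * height

# Original: 24 bits (3 bytes) of RGB per pixel
original_bytes = num_pixels * 3

# Compressed: log2(K) bits per pixel index, plus a K-entry RGB palette
K = 8
bits_per_index = int(np.ceil(np.log2(K)))  # 3 bits for K = 8
compressed_bytes = num_pixels * bits_per_index / 8 + K * 3

print('compression factor: %.1f' % (original_bytes / compressed_bytes))
```

With 8 colours the index needs only 3 bits per pixel, so the image shrinks by roughly a factor of 8.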

## Data ingestion

``````

In [1]:

from scipy import misc  # scipy.misc.imread was removed in SciPy >= 1.2; imageio.imread is the modern replacement
import numpy as np

``````
``````

In [2]:

image = misc.imread('mario.jpg')  # hypothetical filename; the original notebook loaded the hamster photo here
print('Image size', image.shape, 'height x width x RGB')

``````
``````

Image size (900, 1200, 3) height x width x RGB

``````

Transform the image to simple array of pixels with RGB values.

``````

In [3]:

num_pixels = image.shape[0] * image.shape[1]
pixels = image.reshape(num_pixels, 3)
print(pixels)

``````
``````

[[ 95  87  64]
[ 88  82  58]
[107 101  77]
...
[172 175 164]
[174 177 166]
[173 176 165]]

``````
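The reshape above only flattens the two spatial dimensions and keeps the RGB channel intact, so no information is lost. A toy example (with a made-up 2 x 2 "image", not the notebook's data) makes this concrete:

```python
import numpy as np

# Toy 2x2 RGB "image": four pixels with distinct colours
tiny = np.array([[[255, 0, 0], [0, 255, 0]],
                 [[0, 0, 255], [255, 255, 255]]], dtype=np.uint8)

# Flatten the spatial dimensions, keep the RGB channel
flat = tiny.reshape(-1, 3)
print(flat.shape)  # (4, 3)

# The operation is lossless: reshaping back restores the image exactly
restored = flat.reshape(tiny.shape)
print(np.array_equal(tiny, restored))  # True
```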

### Aside: plot function to show points and cluster centers

``````

In [4]:

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

``````
``````

In [5]:

def plot_clusters(points, centers, new_centers, groups, num_to_plot=5000):
    # Only plot num_to_plot random points to speed up rendering
    indices = np.random.choice(points.shape[0], num_to_plot, replace=False)
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    # Data points (coloured by cluster index), centroids before (magenta) and after (red) the update
    ax.scatter(points[indices, 0], points[indices, 1], points[indices, 2],
               c=groups[indices] / 255.0, depthshade=False, alpha=0.05)
    ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2],
               marker='v', s=99, c='m', depthshade=False)
    ax.scatter(new_centers[:, 0], new_centers[:, 1], new_centers[:, 2],
               marker='^', s=99, c='r', depthshade=False)
    plt.show()

``````

## K-means algorithm

Step 0. Initialization

``````

In [6]:

K = 8  # number of centroids (colours in the final output image)
max_iter = 25  # maximal number of iterations
epsilon = 0.5  # stop once the largest squared centroid movement falls below this

# Initialize centroids: K distinct random data points,
# cast to float so the uint8 pixel values don't wrap around when subtracted
centroids = pixels[np.random.choice(pixels.shape[0], K, replace=False), :].astype(float)

``````
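Picking K random data points is the simplest seeding strategy, but it can place several initial centroids in the same dense colour region. A common alternative is k-means++ seeding, which picks each new centroid with probability proportional to its squared distance from the nearest already-chosen one. A minimal sketch (not the notebook's own code; the demo data here is made up):

```python
import numpy as np

def kmeans_pp_init(points, K, seed=None):
    """k-means++ seeding: each new centroid is sampled with probability
    proportional to its squared distance from the nearest chosen centroid."""
    rng = np.random.default_rng(seed)
    points = points.astype(float)
    centroids = [points[rng.integers(len(points))]]
    for _ in range(K - 1):
        # Squared distance of every point to its nearest chosen centroid
        d2 = ((points[:, None, :] - np.array(centroids)) ** 2).sum(axis=2).min(axis=1)
        centroids.append(points[rng.choice(len(points), p=d2 / d2.sum())])
    return np.array(centroids)

# Example on random "pixels"
demo = np.random.default_rng(0).integers(0, 256, size=(1000, 3))
init = kmeans_pp_init(demo, K=8, seed=0)
print(init.shape)  # (8, 3)
```

Because already-chosen points have zero distance (and hence zero probability), the seeds are always distinct and tend to spread across the colour space.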

Step 1. Alternate assignment and update

``````

In [7]:

for iteration in range(max_iter):
    # Assignment step: determine the closest centroid for each point
    diff = pixels.reshape(num_pixels, 1, 3) - centroids
    distance = (diff ** 2).sum(axis=2)
    groups = distance.argmin(axis=1)

    # Update step: calculate new centroids as the mean of the points in each cluster
    new_centroids = np.zeros((K, 3))
    for k in range(K):
        new_centroids[k, :] = np.mean(pixels[groups == k], axis=0)

    # Report and plot
    change = ((centroids - new_centroids) ** 2).sum(axis=1).max()
    print('Iteration', iteration + 1, 'Change in centroid locations', change)
    plot_clusters(pixels, centroids, new_centroids, groups)

    # Stop the iteration once the centroids no longer move
    if change < epsilon:
        break
    else:
        centroids = new_centroids

``````
``````

Iteration 1 Change in centroid locations 2951.8092056273613

Iteration 2 Change in centroid locations 1408.5417173561013

Iteration 3 Change in centroid locations 285.00260915071357

Iteration 4 Change in centroid locations 117.11515490636154

Iteration 5 Change in centroid locations 40.705201890989358

Iteration 6 Change in centroid locations 20.593288294595553

Iteration 7 Change in centroid locations 31.71811667694444

Iteration 8 Change in centroid locations 30.042966742610336

Iteration 9 Change in centroid locations 18.489215811466142

Iteration 10 Change in centroid locations 10.559866042413205

``````
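The notebook stops after the clustering loop, but the introduction describes one more step: replacing each pixel by its cluster's average colour and reassembling the image. That reconstruction is a single fancy-indexing and reshape operation. A self-contained sketch on a tiny synthetic image (the variable names mirror the notebook above, but the data here is made up):

```python
import numpy as np

# Tiny synthetic stand-in for the hamster photo
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 6, 3)).astype(np.uint8)
pixels = image.reshape(-1, 3)

# A few K-means iterations as in the loop above (K = 2 for the toy example)
K = 2
centroids = pixels[rng.choice(pixels.shape[0], K, replace=False), :].astype(float)
for _ in range(10):
    diff = pixels[:, None, :].astype(float) - centroids
    groups = (diff ** 2).sum(axis=2).argmin(axis=1)
    # Keep a centroid in place if its cluster happens to be empty
    centroids = np.array([pixels[groups == k].mean(axis=0)
                          if np.any(groups == k) else centroids[k]
                          for k in range(K)])

# Replace each pixel by its cluster centroid and restore the image shape
compressed = centroids[groups].reshape(image.shape).astype(np.uint8)
print(compressed.shape)  # (4, 6, 3)
print(len(np.unique(compressed.reshape(-1, 3), axis=0)))  # at most K colours
```

The key line is `centroids[groups]`: indexing the (K, 3) centroid array with the per-pixel cluster labels expands it to one RGB row per pixel, which then reshapes back to the original height x width x RGB layout.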