A big thank you to Andrew Ng's Coursera Machine Learning course (this is a Python solution to one of its class assignments) and to Isabell Graf (who made me aware of this neat problem and helped develop the code below). The hamster's name is Mario :)
We will use K-means clustering to compress the image above to an image with fewer colours.
More precisely, each pixel is treated as a data point, regardless of its location in the image. K-means clustering is run on the RGB values to obtain K representative colours, one average colour per cluster. Each pixel is then replaced by the average colour of its cluster, and the image is reassembled from the transformed pixel values.
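For reference, the procedure can be written as minimizing the average within-cluster squared distance. The notation here is assumed for illustration and is not used elsewhere in this notebook: $x^{(i)} \in \mathbb{R}^3$ is the RGB vector of pixel $i$, $m$ the number of pixels, $c^{(i)}$ the index of the cluster pixel $i$ is assigned to, and $\mu_k$ the centroid of cluster $k$.

```latex
J = \frac{1}{m} \sum_{i=1}^{m} \left\lVert x^{(i)} - \mu_{c^{(i)}} \right\rVert^2
\qquad
c^{(i)} = \arg\min_{k} \left\lVert x^{(i)} - \mu_k \right\rVert^2
\qquad
\mu_k = \frac{1}{\lvert C_k \rvert} \sum_{i \in C_k} x^{(i)}, \quad C_k = \{\, i : c^{(i)} = k \,\}
```

The assignment step updates the $c^{(i)}$ with the centroids held fixed, and the update step recomputes the $\mu_k$ with the assignments held fixed. Neither step can increase $J$, which is why alternating them converges, although only to a local minimum that depends on the initial centroids.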
In [1]:
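# scipy.misc.imread/imsave were removed from recent SciPy releases; imageio is a common replacement
import imageio.v2 as imageio
import numpy as np
In [2]:
image = imageio.imread('Mario_mit_Apfel.jpg')
print('Image size', image.shape, 'height x width x RGB')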
Reshape the image into a simple array of pixels, one row of RGB values per pixel.
In [3]:
num_pixels = image.shape[0] * image.shape[1]
pixels = image.reshape(num_pixels, 3)
print(pixels)
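As a quick sanity check of what the reshape does, here is a small sketch on a made-up 2 x 3 image (the array `toy` is hypothetical). Rows are flattened in row-major order, and reshaping back with the original dimensions restores the image exactly, which is what the final reassembly step relies on.

```python
import numpy as np

# A hypothetical 2 x 3 "image" (height x width x RGB) with made-up pixel values
toy = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)

# Flatten to one row per pixel; pixels are taken in row-major (C) order
toy_pixels = toy.reshape(-1, 3)
print(toy_pixels.shape)  # (6, 3)

# The pixel at image row 1, column 2 becomes row 1 * 3 + 2 = 5 of the flat array
print(toy_pixels[5], toy[1, 2])

# Reshaping back with the original height and width restores the image exactly
restored = toy_pixels.reshape(toy.shape)
print(np.array_equal(restored, toy))  # True
```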
In [4]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline
In [5]:
def plot_clusters(points, centers, new_centers, groups, num_to_plot=5000):
    # only plot num_to_plot random points to speed up rendering
    num_to_plot = min(num_to_plot, points.shape[0])
    indices = np.random.choice(points.shape[0], num_to_plot, replace=False)
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    # sampled pixels in RGB space, coloured by cluster label via the default colormap
    ax.scatter(points[indices, 0], points[indices, 1], points[indices, 2], c=groups[indices], depthshade=False, alpha=0.05)
    # previous centroids: magenta downward triangles
    ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], marker='v', s=99, c='m', depthshade=False)
    # new centroids: red upward triangles
    ax.scatter(new_centers[:, 0], new_centers[:, 1], new_centers[:, 2], marker='^', s=99, c='r', depthshade=False)
    plt.show()
Step 0. Initialization
In [6]:
K = 8  # number of centroids (colours in the final output image)
max_iter = 25  # maximum number of iterations
epsilon = 0.5  # stop once the largest squared centroid displacement falls below this
# Initialize centroids: K randomly chosen pixels, stored as floats
centroids = pixels[np.random.choice(pixels.shape[0], K, replace=False), :].astype(float)
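The pixel values come in as unsigned 8-bit integers (NumPy dtype `uint8`), and NumPy arithmetic on `uint8` arrays wraps around modulo 256 instead of going negative. The small sketch below, using made-up values, shows why the squared distances in the assignment step should be computed in floating point.

```python
import numpy as np

a = np.array([10, 200], dtype=np.uint8)
b = np.array([20, 100], dtype=np.uint8)

# uint8 arithmetic wraps modulo 256: 10 - 20 becomes 246 instead of -10
wrapped = a - b
print(wrapped)  # [246 100]

# Squaring also wraps in uint8, so the "distance" is meaningless
wrapped_sq = (wrapped ** 2).sum()
print(wrapped_sq)  # 116

# Casting to float first gives the true differences and squared distance
diff = a.astype(float) - b.astype(float)
print(diff, (diff ** 2).sum())  # [-10. 100.] 10100.0
```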
Step 1. Alternate assignment and update
In [7]:
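for iteration in range(max_iter):
    # Assignment step: determine the closest centroid for each point.
    # Broadcasting (N, 1, 3) - (K, 3) gives an (N, K, 3) array of differences;
    # cast to float first so uint8 pixel values cannot wrap around.
    diff = pixels.reshape(num_pixels, 1, 3).astype(float) - centroids
    distance = (diff ** 2).sum(axis=2)  # squared Euclidean distances, shape (N, K)
    groups = distance.argmin(axis=1)    # index of the nearest centroid, shape (N,)
    # Update step: calculate new centroids as mean of points in cluster
    new_centroids = np.zeros((K, 3))
    for k in range(K):
        members = pixels[groups == k]
        # Keep the old centroid if a cluster is empty (its mean would be NaN)
        new_centroids[k, :] = members.mean(axis=0) if len(members) > 0 else centroids[k, :]
    # Report and plot
    change = ((centroids - new_centroids) ** 2).sum(axis=1).max()
    print('Iteration', iteration + 1, 'Change in centroid locations', change)
    plot_clusters(pixels, centroids, new_centroids, groups)
    # Stop the iteration if the centroids don't move anymore
    if change < epsilon:
        break
    else:
        centroids = new_centroids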
The magenta (downward) and red (upward) triangles indicate the locations of the previous and new cluster centroids, respectively.
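The loop above can also be packaged as a reusable function. This is a minimal sketch, not the notebook's own code: the name `kmeans`, the `seed` parameter and the toy data are hypothetical, plotting is left out, and an empty cluster simply keeps its previous centroid (one common choice among several).

```python
import numpy as np

def kmeans(points, K, max_iter=25, epsilon=1e-6, seed=0):
    """Plain k-means on an (N, D) array; returns (centroids, labels).

    Hypothetical helper mirroring the notebook loop, without plotting.
    """
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=float)
    # Initialize with K randomly chosen data points (distinct indices)
    centroids = points[rng.choice(len(points), K, replace=False)]
    for _ in range(max_iter):
        # Assignment step: squared distances via broadcasting, shape (N, K)
        distance = ((points[:, None, :] - centroids) ** 2).sum(axis=2)
        labels = distance.argmin(axis=1)
        # Update step: mean of each cluster; keep the old centroid if a cluster is empty
        new_centroids = np.array([
            points[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
            for k in range(K)
        ])
        # Largest squared centroid displacement, as in the notebook loop
        change = ((centroids - new_centroids) ** 2).sum(axis=1).max()
        centroids = new_centroids
        if change < epsilon:
            break
    # Final assignment so the labels match the returned centroids
    distance = ((points[:, None, :] - centroids) ** 2).sum(axis=2)
    return centroids, distance.argmin(axis=1)

# Two well-separated groups of made-up "colours"
data = np.array([[0, 0, 0], [2, 0, 0], [0, 2, 0],
                 [250, 250, 250], [252, 250, 250], [250, 252, 250]])
centroids, labels = kmeans(data, K=2)
print(np.round(np.sort(centroids[:, 0]), 3))
```

Calling `kmeans(pixels, K=8)` should give a palette comparable to the loop above, although the exact colours depend on the random initialization.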
In [8]:
# Replace each original pixel with the RGB values of the centroid of its group,
# rounded back to 8-bit integers so the result is a valid image.
pixels_compressed = np.round(centroids[groups, :]).astype(np.uint8)
image_compressed = pixels_compressed.reshape(image.shape[0], image.shape[1], 3)
plt.imsave('Mario_compressed.png', image_compressed)
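To put the compression into numbers: if the output were stored as one cluster index per pixel plus a palette of K colours, each pixel would need ceil(log2 K) bits instead of 24. The sketch below computes this idealized size for a hypothetical 128 x 128 image (the function name `palette_bits` is made up). The PNG written above is not necessarily stored this way, so its actual file size depends on PNG's own compression rather than on this calculation.

```python
import math

def palette_bits(height, width, K, bits_per_channel=8, channels=3):
    """Idealized bit counts for a full-colour image vs. a K-colour palette image."""
    original = height * width * channels * bits_per_channel
    index_bits = math.ceil(math.log2(K))        # bits needed to store one cluster index
    palette = K * channels * bits_per_channel    # the K centroid colours themselves
    compressed = height * width * index_bits + palette
    return original, compressed

# Hypothetical 128 x 128 image, K = 8 colours
original, compressed = palette_bits(128, 128, K=8)
print(original, compressed, round(original / compressed, 2))  # 393216 49344 7.97
```

For large images the palette term becomes negligible, so the ratio approaches 24 / ceil(log2 K), i.e. 8 for K = 8.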