In [ ]:
import random
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from PIL import Image
In [ ]:
%matplotlib inline
In [ ]:
# Load the example dataset. NOTE(review): the variable is named ex7data1 but
# the file that is read is ex7data2.mat.
ex7data1 = scipy.io.loadmat('ex7data2.mat')
X = ex7data1['X']  # the 'X' entry of the .mat file is used as the dataset
In [ ]:
def find_closest_centroids(X, centroids):
    """Assign every example to its nearest centroid.

    Parameters
    ----------
    X : array_like, shape (m, n)
        Dataset in which each row is a single example.
    centroids : array_like, shape (K, n)
        Current centroid positions, one centroid per row.

    Returns
    -------
    idx : ndarray of int, shape (m,)
        ``idx[i]`` is the zero-based index (a value in ``0..K-1``) of the
        centroid closest to ``X[i]`` by squared Euclidean distance. When two
        centroids are equally close, the lower index wins.
    """
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    # Broadcast (m, 1, n) against (1, K, n) to get every example-to-centroid
    # offset at once, then reduce over the feature axis -> (m, K) distances.
    offsets = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    sq_dists = np.sum(offsets ** 2, axis=2)
    # The square root is monotonic, so squared distances rank identically.
    return np.argmin(sq_dists, axis=1).astype('int')
The closest centroids to the first 3 examples should be [0, 2, 1]
respectively.
In [ ]:
# Assign every example to the nearest of three hand-picked starting centroids.
K = 3
initial_centroids = np.array([[3, 3], [6, 2], [8, 5]])
idx = find_closest_centroids(X, initial_centroids)
idx[:3]  # display the assignments of the first three examples
In [ ]:
def compute_centroids(X, idx, K):
    """Return the new centroids as the mean of the points assigned to each.

    Parameters
    ----------
    X : array_like, shape (m, n)
        Dataset in which each row is a single data point.
    idx : array_like of int, shape (m,)
        Zero-based centroid assignment of every data point (values in
        ``0..K-1``).
    K : int
        Number of centroids.

    Returns
    -------
    centroids : ndarray of float, shape (K, n)
        Row ``k`` is the mean of all rows of ``X`` whose assignment is ``k``.
        A centroid with no assigned points is left as a row of zeros.
    """
    X = np.asarray(X, dtype=float)
    idx = np.asarray(idx)
    centroids = np.zeros((K, X.shape[1]))
    for k in range(K):
        members = X[idx == k]
        # An empty cluster has no mean; keep its row at zero instead of
        # letting the mean of an empty selection produce NaN.
        if members.shape[0] > 0:
            centroids[k] = members.mean(axis=0)
    return centroids
Centroids computed after initial finding of closest centroids:
In [ ]:
# Recompute the centroids from the current assignments in idx and display them.
compute_centroids(X, idx, K)
The centroids should be:
array([[ 2.42830111, 3.15792418],
[ 5.81350331, 2.63365645],
[ 7.11938687, 3.6166844 ]])
In [ ]:
def plot_data_points(X, idx, K, ax):
    """Scatter the first two columns of X, coloured by centroid assignment.

    The colours come from an 'hsv' colormap resampled to ``np.max(idx) + 2``
    entries and looked up with the values in ``idx``. ``K`` is accepted for
    signature symmetry with the other plotting helper but is not used here.
    """
    n_colors = np.max(idx) + 2
    cmap = plt.get_cmap('hsv', n_colors)
    xs, ys = X[:, 0], X[:, 1]
    ax.scatter(xs, ys, c=cmap(idx))
In [ ]:
def plot_kmeans_progress(X, centroids, previous_centroids, idx, K, iteration_number, ax):
    """Draw one K-Means iteration: data points, centroids and their movement.

    Parameters
    ----------
    X, idx, K
        Forwarded unchanged to ``plot_data_points``.
    centroids : ndarray
        Current centroids; columns 0 and 1 are plotted as black crosses.
    previous_centroids : ndarray or None
        Centroids of the previous iteration. When not ``None``, a blue line
        is drawn from each current centroid to its predecessor.
    iteration_number : int
        Number shown in the axes title.
    ax : matplotlib Axes
        Axes to draw on.
    """
    plot_data_points(X, idx, K, ax)
    # Pass the marker colour exactly once: Axes.scatter raises a ValueError
    # when both ``c`` and ``color`` are supplied.
    ax.scatter(centroids[:, 0], centroids[:, 1], color='black', marker='x', s=50, linewidths=4)
    if previous_centroids is not None:
        for current, previous in zip(centroids, previous_centroids):
            ax.plot([current[0], previous[0]], [current[1], previous[1]], 'b-')
    ax.set_title('Iteration {}'.format(iteration_number))
In [ ]:
# Settings for the K-Means run on the example dataset.
K = 3
initial_centroids = np.array([[3, 3], [6, 2], [8, 5]])
max_iters = 10
In [ ]:
def run_kmeans(X, initial_centroids, max_iters, plot_progress=False):
    """Run K-Means for a fixed number of iterations.

    Each iteration assigns the examples to their closest centroid with
    ``find_closest_centroids`` and then moves the centroids with
    ``compute_centroids``. With ``plot_progress`` set, every iteration is
    drawn onto one shared 6x6 inch figure via ``plot_kmeans_progress``.

    Returns ``(centroids, idx)``. ``idx`` holds the assignments made at the
    start of the last iteration, i.e. before the final centroid update.
    ``max_iters`` must be at least 1, otherwise ``idx`` is never assigned.
    """
    if plot_progress:
        _figure, ax = plt.subplots(figsize=(6, 6))
    # Not used below; the unpacking fails early unless X is two-dimensional.
    n_examples, n_features = X.shape
    K = initial_centroids.shape[0]
    previous_centroids = None
    centroids = initial_centroids
    for step in range(1, max_iters + 1):
        idx = find_closest_centroids(X, centroids)
        if plot_progress:
            plot_kmeans_progress(X, centroids, previous_centroids, idx, K, step, ax)
        # Remember where the centroids were before moving them.
        previous_centroids, centroids = centroids, compute_centroids(X, idx, K)
    return centroids, idx
In [ ]:
# Run K-Means with progress plotting enabled; both return values are discarded.
# NOTE(review): the iteration count is the literal 10 rather than max_iters.
_, __ = run_kmeans(X, initial_centroids, 10, True)
In [ ]:
# Load an image of a bird
im = Image.open('bird_small.png')
X = np.array(im)
# Divide by 255 so that all values are in the range 0 - 1
# (assumes 8-bit channel values -- TODO confirm for this image).
X = X/255
Reshape the image into an Nx3 matrix where N = number of pixels. Each row will contain the Red, Green and Blue pixel values. This gives us our dataset matrix X that we will use K-Means on.
In [ ]:
# Flatten the image to one row per pixel; the 128x128 size and the three
# colour channels are hard-coded here.
X = X.reshape((128*128, 3))
img_size = X.shape  # shape of the reshaped matrix, not of the original image
img_size, X.dtype
You should now complete the code in kmeans_init_centroids.
In [ ]:
def kmeans_init_centroids(X, K):
    """Pick K distinct rows of X at random to serve as initial centroids.

    Parameters
    ----------
    X : array_like, shape (m, n)
        Dataset in which each row is a single example.
    K : int
        Number of centroids to draw.

    Returns
    -------
    centroids : ndarray of float, shape (K, n)
        Copies of K randomly chosen rows of ``X``.

    Raises
    ------
    ValueError
        If ``K`` is larger than the number of rows in ``X``.
    """
    X = np.asarray(X)
    # Sample row indices without replacement so that no two centroids start
    # on the same row of X.
    rows = np.random.choice(X.shape[0], size=K, replace=False)
    return X[rows].astype(float)
Run your K-Means algorithm on this data. You should try different values of K and max_iters here.
In [ ]:
K = 16  # number of centroids requested from kmeans_init_centroids below
max_iters = 10  # iteration count passed to run_kmeans below
When using K-Means, it is important to initialize the centroids randomly. You should complete the code in kmeans_init_centroids before proceeding.
In [ ]:
# Choose the starting centroids for the image data.
initial_centroids = kmeans_init_centroids(X, K)
In [ ]:
# Run K-Means on the pixels without progress plotting (plot_progress defaults to False).
centroids, idx = run_kmeans(X, initial_centroids, max_iters)
In [ ]:
# Re-assign every pixel against the final centroids: the idx returned by
# run_kmeans was computed before its last centroid update.
idx = find_closest_centroids(X, centroids)
Essentially, now we have represented the image X in terms of the indices in idx.
We can now recover the image from the indices (idx) by mapping each pixel (specified by its index in idx) to the centroid value.
In [ ]:
# Replace every pixel by its assigned centroid, restore the hard-coded
# 128x128x3 image layout, then scale by 255 and cast to 8-bit values
# (the cast truncates the fractional part).
X_recovered = centroids[idx, :].reshape([128, 128, 3])
X_recovered = (X_recovered * 255).astype('uint8')
X_recovered.shape
Here are the centroid colors:
In [ ]:
# Show each centroid as one large dot in its own colour.
# NOTE: the 4x4 grid holds at most 16 centroids; zip() silently stops at the
# shorter of the two sequences.
fig, axes = plt.subplots(nrows=4, ncols=4)
axes = axes.flat
for centroid, ax in zip(centroids, axes):
    ax.set_axis_off()
    # Pass the RGB triple through ``color``: given to ``c`` a bare numeric
    # sequence is ambiguous with values that should be colormapped.
    ax.scatter(1, 1, color=tuple(centroid), s=1000)
And the images, original and compressed.
In [ ]:
# Show the compressed image next to the original, without axis decorations.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8,10))
panels = [
    (X_recovered, 'Compressed'),
    (np.array(Image.open('bird_small.png')), 'Original'),
]
for ax, (picture, caption) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(caption)
    ax.set_axis_off()