K-means: application to MNIST

Here we propose a framework for applying the K-means method to the MNIST database. The algorithm we use is Lloyd's algorithm (Voronoi iteration). We will study its performance and its "live" behavior to understand how the clusters are computed.

2D explanation

import numpy as np

def random_sample(X, k):
    """Return k distinct rows of X, chosen uniformly at random.

    Used to pick the initial centroids of Lloyd's algorithm.

    Parameters
    ----------
    X : array-like of shape (n, d)
        Data points.
    k : int
        Number of rows to draw (must not exceed n).

    Returns
    -------
    numpy.ndarray of shape (k, d)
    """
    X = np.asarray(X)
    # replace=False guarantees k *distinct* starting centroids.
    return X[np.random.choice(X.shape[0], k, replace=False), :]

def pairwise_distances_argmin(X, y):
    """Return, for every point of X, the index of its closest centroid in y.

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    y : array of shape (k, d)
        Centroids.

    Returns
    -------
    numpy.ndarray of shape (n,), dtype intp
        ``indices[i]`` is the row of ``y`` nearest (Euclidean) to ``X[i]``;
        ties go to the lowest index.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    indices = np.empty(len(X), dtype=np.intp)
    for i, point in enumerate(X):
        # Distance from this point to every centroid; argmin picks the nearest.
        indices[i] = np.linalg.norm(y - point, axis=1).argmin()
    return indices

def kmeans_iteration(X, m):
    """Perform one iteration of Lloyd's algorithm.

    Each point of X is assigned to its nearest centroid in m, then every
    centroid is moved to the mean of the points assigned to it.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, d)
        Data points.
    m : numpy.ndarray of shape (k, d)
        Current centroids.

    Returns
    -------
    (centroids, clusters)
        The updated (k, d) centroids and the cluster index of every point.
    """
    clusters = pairwise_distances_argmin(X, m)
    centroids = np.empty(m.shape)
    for i in range(len(m)):
        members = X[clusters == i]
        # An empty cluster would make np.mean return NaN (with a warning);
        # keep the previous centroid instead so the iteration stays finite.
        centroids[i] = members.mean(axis=0) if len(members) else m[i]
    return centroids, clusters

def kmeans(X, k, max_iter=None):
    """Run Lloyd's algorithm on X with k clusters until the centroids settle.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, d)
        Data points.
    k : int
        Number of clusters.
    max_iter : int or None, optional
        Cap on the number of iterations. None (the default) iterates until
        convergence.

    Returns
    -------
    (centroids, clusters)
        Final (k, d) centroids and the cluster index of every point.
    """
    m = random_sample(X, k)
    iteration = 0
    while True:
        iteration += 1
        new_m, clusters = kmeans_iteration(X, m)
        # Converged when no centroid moved (up to floating-point tolerance).
        if np.isclose(m, new_m).all():
            break
        if max_iter is not None and iteration >= max_iter:
            break
        m = new_m
    return new_m, clusters

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import rc
import seaborn as sns

sns.set_style('whitegrid')  # white background with a grid for every plot
rc('figure', figsize=(6, 4))
# One colour per cluster. matplotlib.cm.get_cmap was deprecated in 3.7 and
# removed in 3.9; pyplot.get_cmap returns the same colormap object.
cmap = plt.get_cmap('rainbow')

def plot_clusters(X, m, clusters):
    """Plot all the points of each cluster in that cluster's colour.

    Each centroid is drawn as a black-edged square of the same colour.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, 2)
        2-D data points.
    m : numpy.ndarray of shape (k, 2)
        Centroids.
    clusters : numpy.ndarray of shape (n,)
        Cluster index of each point.
    """
    k = len(m)
    for i in range(k):
        color = cmap(i / k)
        group = X[clusters == i]
        plt.scatter(group[:, 0], group[:, 1], marker='.', color=color)
        plt.scatter(m[i, 0], m[i, 1], marker='s', lw=2, color=color, edgecolor='k')

from sklearn.datasets import make_blobs
k = 6  # Number of centers
sigma = 3  # Std deviation

X, _ = make_blobs(n_samples=2000, centers=k, cluster_std=sigma)
m, clusters = kmeans(X, k)
plt.title('Results of Lloyds algorithm for {} clusters, with standard deviation {}' .format(k,sigma))
plot_clusters(X, m, clusters)

# Without cluster classification. Start a new figure so the raw blobs are not
# drawn on top of the clustered plot when this runs as a single script.
plt.figure()
plt.plot(X[:, 0], X[:, 1],
         linestyle='', marker='.',
         color=cmap(0.2), markeredgecolor=cmap(0.25))
plt.title('Blops for {} clusters, with standard deviation {} without identification' .format(k,sigma))

import matplotlib.animation as animation

class KMeansAnimation:
    """Drive a matplotlib ``FuncAnimation`` that shows Lloyd's iterations.

    Each cluster gets one line artist for its points and one for its
    centroid. ``update`` runs one K-means iteration per frame and moves the
    artists accordingly.
    """

    def __init__(self, fig, ax, X, m=None, k=2):
        """
        Parameters
        ----------
        fig : matplotlib Figure holding the animation.
        ax : matplotlib Axes to draw on.
        X : numpy.ndarray of shape (n, 2)
            2-D data points.
        m : numpy.ndarray of shape (k, 2), optional
            Initial centroids; a random sample of X when omitted.
        k : int
            Number of clusters when ``m`` is not given.
        """
        self.X = X
        self.fig = fig
        self.m = m if m is not None else random_sample(X, k)
        # When explicit centroids are given, their count is authoritative;
        # otherwise update() would index past the end of the artist lists.
        k = len(self.m)
        # We have to call plot for each cluster and its centroid
        # because we want to distinguish the clusters by color
        # and draw the centroid with a different marker
        self.clusters, self.centroids = [], []
        for i in range(k):
            color = cmap(i / k)
            self.clusters.append(
                ax.plot([], [],
                        linestyle='', marker='.',
                        markeredgecolor=color, color=color)[0])
            self.centroids.append(
                ax.plot([], [],
                        linestyle='', marker='s',
                        markersize=8, color=color, markeredgecolor='k')[0])

    def update(self, t):
        """Run one K-means iteration for frame ``t``; return the changed artists."""
        self.m, clusters = kmeans_iteration(self.X, self.m)
        self.fig.suptitle('n = {}, centers = {} – Iteration {}'.format(
                len(self.X), len(self.m), t + 1))
        # To update the plot, we simply call set_data on the saved artists
        for i in range(len(self.m)):
            group = self.X[clusters == i]
            self.clusters[i].set_data(group[:, 0], group[:, 1])
            # set_data needs sequences (scalars are rejected by recent Matplotlib).
            self.centroids[i].set_data([self.m[i, 0]], [self.m[i, 1]])
        return self.clusters + self.centroids

from IPython.display import HTML

def make_animation(X, k, m=None, frames=20):
    """Animate K-means on 2-D data and return it as an HTML5 video.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, 2)
        Data points; the axes limits are its bounding box.
    k : int
        Number of clusters (ignored when ``m`` is given).
    m : numpy.ndarray of shape (k, 2), optional
        Initial centroids.
    frames : int
        Number of K-means iterations to animate.

    Returns
    -------
    IPython.display.HTML embedding the video, for display in a notebook.
    """
    fig = plt.figure(figsize=(6, 4))
    (xmin, ymin), (xmax, ymax) = np.min(X, axis=0), np.max(X, axis=0)
    ax = plt.axes(xlim=(xmin, xmax), ylim=(ymin, ymax))
    control = KMeansAnimation(fig, ax, X, m=m, k=k)
    anim = animation.FuncAnimation(
        fig, control.update,
        frames=frames, interval=700, blit=True,
    )
    # Necessary, otherwise the notebook will display the final figure
    # along with the animation
    plt.close(fig)
    return HTML(anim.to_html5_video())

# Six blobs but only four centroids: shows how Lloyd's algorithm behaves
# when there are fewer centers than blobs.
k=6 # Number of blobs
sigma = 3 # Std deviation

X,_ = make_blobs(n_samples=2000, centers=k, cluster_std=sigma)
print ('Results of Lloyds algorithm for {} clusters, with standard deviation {}, with less centers than blops' .format(k,sigma))
make_animation(X, k=4, frames=20)

Results of Lloyds algorithm for 6 clusters, with standard deviation 3, with less centers than blops

sigma = 2.2  # Std deviation of this experiment (reported in the message below)
X, _ = make_blobs(n_samples=2000, centers=6, cluster_std=sigma)
# Four initial centroids placed deliberately close to each other.
m = np.array([[0, 2.5], [3, 11], [4, 10], [5.5, 2]])
print ('Results of Lloyds algorithm for {} blops, with standard deviation {}, with close centroids at time 0' .format(k,sigma))
# Pass the hand-picked centroids; without m=m a random start was used instead.
make_animation(X, k=len(m), m=m, frames=20)

Results of Lloyds algorithm for 6 blops, with standard deviation 2.2, with close centroids at time 0

# Points drawn uniformly in the unit square (no cluster structure),
# animated with k=11 clusters.
X_dense = np.random.rand(2000, 2)
make_animation(X_dense, k=11)


PCA data

The results are poor, but reducing the data to two dimensions lets us plot a 2D graph and an animation.

import random
from base64 import b64decode
from json import loads
import numpy as np
import matplotlib.pyplot as plt
# set matplotlib to display all plots inline with the notebook
%matplotlib inline

def parse(x):
    """Parse one line of the digits file into a ``(label, vector)`` tuple.

    Each line is a JSON object with a ``"label"`` field and a ``"data"``
    field holding the base64-encoded raw bytes of the digit image.

    Returns
    -------
    (label, numpy.ndarray)
        The label and a float64 array with one element per image byte.
    """
    digit = loads(x)
    # np.fromstring is deprecated for binary data; frombuffer is the
    # replacement. astype() copies, so the result is writable.
    array = np.frombuffer(b64decode(digit["data"]), dtype=np.ubyte)
    return (digit["label"], array.astype(np.float64))

# read in the digits file. Digits is a list of 60,000 tuples,
# each containing a labelled digit and its vector representation.
from sklearn import metrics
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

with open("digits.base64.json", "r") as f:
    digits = list(map(parse, f.readlines()))

# Use the first 75% of the digits for this visual experiment.
n_used = int(len(digits) * 0.75)
label_digit = [label for label, _ in digits[:n_used]]
data = np.array([vector for _, vector in digits[:n_used]])
data = scale(data)  # standardize every pixel feature
reduced_data = PCA(n_components=2).fit_transform(data)  # Reduce the data to 2D so it can be plotted

# Animate K-means with 10 clusters on the 2-D PCA projection of the digits.
print ('PCA reduced MNIST Data and K-Means algorithm for 10 cluster')
make_animation(reduced_data, k=10, frames=20)

PCA reduced MNIST Data and K-Means algorithm for 10 cluster

The question we can ask is: is this classification accurate, or are the centroids essentially random? To find out, we plot each point with the color of its digit label instead of the color of its cluster.

# Label of every point of reduced_data (same order as the digits list).
label_digit = [label for label, _ in digits[:len(reduced_data)]]
training = list(zip(label_digit, reduced_data))

colormap = np.array(['r', 'g', 'b', 'y', 'm', 'k', 'c', 'crimson', 'orange', 'pink'])
labels_array = np.asarray(label_digit)
for i in range(10):  # one scatter call per digit so the legend has one entry per digit
    group = reduced_data[labels_array == i]
    plt.scatter(group[:, 0], group[:, 1], marker='.', color=colormap[i], label=str(i))

plt.title('Reduced Mnist datas')
plt.legend(loc=9, bbox_to_anchor=(1.3, 0.6), ncol=3)

Full digits images

with open("digits.base64.json", "r") as f:
    digits = list(map(parse, f.readlines()))

# pick a ratio for splitting the digits list into a training and a validation set. Classic is 0.25
ratio = int(len(digits) * 0.25)
validation = digits[:ratio]
training = digits[ratio:]

# Stack vectors into (n_samples, n_pixels) matrices and labels into
# (n_samples, 1) float columns. The training matrices cover the whole
# training split, not just its first `ratio` rows.
validation_data = np.array([vector for _, vector in validation])
validation_label = np.array([[label] for label, _ in validation], dtype=np.float64)
training_data = np.array([vector for _, vector in training])
training_label = np.array([[label] for label, _ in training], dtype=np.float64)

# Cluster the validation vectors, then regroup them as (label, vector)
# pairs per cluster so each centroid can be given a majority label.
m, clusters = kmeans(validation_data, 20)
clusters_list = {c: [] for c in range(len(m))}
for label, vector, c in zip(validation_label[:, 0], validation_data, clusters):
    clusters_list[c].append((label, vector))
# NOTE(review): assign_labels_to_centroids and get_error_rate are defined
# further down in this file; in a top-to-bottom run this cell must come after them.
labelled_centroids = assign_labels_to_centroids(list(clusters_list.values()), m)
# The error is measured on the same vectors that were clustered.
error_rate = get_error_rate(validation, labelled_centroids)

print(error_rate)


def display_digit(digit, labeled=True, title=""):
    """Show a 784-element digit vector as a 28x28 image.

    Parameters
    ----------
    digit : tuple or array-like
        A ``(label, vector)`` pair when ``labeled`` is True, else the vector.
    labeled : bool
        Whether ``digit`` is a labelled pair.
    title : object
        When not the empty string, shown as ``"Inferred label: <title>"``.
        A label of 0 is still displayed (the check is against "" only).
    """
    image = digit[1] if labeled else digit
    plt.imshow(np.asarray(image).reshape(28, 28))
    if title != "":
        plt.title("Inferred label: " + str(title))

# Lloyd's algorithm for K-means clustering on (label, vector) pairs.
# For better performance, scikit-learn's KMeans could be used instead.
def init_centroids(labelled_data, k):
    """Randomly pick k datapoints as the starting centroids.

    Parameters
    ----------
    labelled_data : list of (label, vector)
        Must be a sequence (``random.sample`` requirement).
    k : int
        Number of centroids.

    Returns
    -------
    list of vectors
        The labels are removed. A list (not a ``map`` iterator) so callers
        can index it and take its length.
    """
    return [vector for _label, vector in random.sample(labelled_data, k)]

def sum_cluster(labelled_cluster):
    """Element-wise sum of the vectors of a labelled cluster.

    Adapted from
    http://stackoverflow.com/questions/20640396/quickly-summing-numpy-arrays-element-wise

    Parameters
    ----------
    labelled_cluster : list of (label, vector)
        Non-empty cluster; the labels are ignored.

    Returns
    -------
    numpy.ndarray
        A new array; the input vectors are not modified.

    Raises
    ------
    ValueError
        If the cluster is empty.
    """
    if len(labelled_cluster) == 0:
        raise ValueError("cannot sum an empty cluster")
    # copy() so the in-place += below never mutates the first datapoint.
    sum_ = labelled_cluster[0][1].copy()
    for _label, vector in labelled_cluster[1:]:
        sum_ += vector
    return sum_

def mean_cluster(labelled_cluster):
    """Return the mean vector (centroid) of a labelled cluster.

    Parameters
    ----------
    labelled_cluster : list of (label, vector)
        Non-empty cluster; the labels are ignored.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of the vectors.

    Raises
    ------
    ValueError
        If the cluster is empty (its mean is undefined).
    """
    if len(labelled_cluster) == 0:
        raise ValueError("cannot average an empty cluster")
    # Strip the labels so np.mean sees a plain (n, d) stack of vectors.
    return np.mean([vector for _label, vector in labelled_cluster], axis=0)

def form_clusters(labelled_data, unlabelled_centroids):
    """Allocate each datapoint to its closest centroid, forming clusters.

    Parameters
    ----------
    labelled_data : iterable of (label, vector)
    unlabelled_centroids : iterable of vectors

    Returns
    -------
    list of lists
        ``clusters[j]`` holds the (label, vector) pairs closest to centroid
        j (ties go to the lowest index). It is a real list so callers can
        index it.
    """
    # Materialize once: accepts lists as well as one-shot iterators.
    centroids = list(unlabelled_centroids)
    clusters = [[] for _ in centroids]
    for label, Xi in labelled_data:
        distances = [np.linalg.norm(Xi - cj) for cj in centroids]
        # np.argmin returns the first minimum, matching a strict "<" scan.
        closest_centroid_index = int(np.argmin(distances))
        clusters[closest_centroid_index].append((label, Xi))
    return clusters

def move_centroids(labelled_clusters):
    """Return the new centroids: the mean vector of each cluster.

    Parameters
    ----------
    labelled_clusters : iterable of lists of (label, vector)

    Returns
    -------
    list of numpy.ndarray, in the same order as the clusters.
    An empty cluster makes mean_cluster fail.
    """
    return [mean_cluster(cluster) for cluster in labelled_clusters]

def repeat_until_convergence(labelled_data, labelled_clusters, unlabelled_centroids, max_iter):
    """Alternate centroid moves and cluster re-assignment until convergence.

    Stops when no centroid moved during an iteration (the largest
    displacement is zero) or after ``max_iter`` iterations.

    Parameters
    ----------
    labelled_data : list of (label, vector)
    labelled_clusters : clusters formed around ``unlabelled_centroids``
    unlabelled_centroids : iterable of vectors
    max_iter : int
        Maximum number of iterations.

    Returns
    -------
    (labelled_clusters, unlabelled_centroids) from the last iteration.
    """
    iterations = 0
    while True:
        iterations += 1
        unlabelled_old_centroids = list(unlabelled_centroids)
        unlabelled_centroids = move_centroids(labelled_clusters)
        labelled_clusters = form_clusters(labelled_data, unlabelled_centroids)
        # Largest distance travelled by any centroid during this step. Once
        # the centroids stop moving, the clusters cannot change any more.
        max_difference = max(
            np.linalg.norm(a - b)
            for a, b in zip(unlabelled_old_centroids, unlabelled_centroids))
        if max_difference == 0 or iterations >= max_iter:
            break
    return labelled_clusters, unlabelled_centroids

def cluster(labelled_data, k, max_iter=20):
    """Run K-means clustering on labelled data.

    The labels are carried along but ignored while clustering; they are
    only used afterwards to name the clusters.

    Parameters
    ----------
    labelled_data : list of (label, vector)
    k : int
        Number of clusters.
    max_iter : int, optional
        Iteration cap passed to repeat_until_convergence (default 20).

    Returns
    -------
    (final_clusters, final_centroids)
    """
    # list() also accepts an iterator, so the centroids can be reused below.
    centroids = list(init_centroids(labelled_data, k))
    clusters = form_clusters(labelled_data, centroids)
    final_clusters, final_centroids = repeat_until_convergence(labelled_data, clusters, centroids, max_iter)
    return final_clusters, final_centroids

def assign_labels_to_centroids(clusters, centroids):
    """Assign to each centroid the most common label of its cluster.

    Parameters
    ----------
    clusters : sequence or iterable of lists of (label, vector)
        Must be in the same order as ``centroids``.
    centroids : sequence of vectors

    Returns
    -------
    list of (label, centroid)
        Empty clusters have no majority label and are skipped, so their
        centroid is never used for classification.
    """
    labelled_centroids = []
    for members, centroid in zip(clusters, centroids):
        labels = [label for label, _vector in members]
        if not labels:
            continue
        # pick the most common label
        most_common = max(set(labels), key=labels.count)
        labelled_centroids.append((most_common, centroid))
    return labelled_centroids

def classify_digit(digit, labelled_centroids):
    """Classify a digit with the label of its closest centroid.

    Parameters
    ----------
    digit : numpy.ndarray
        Unlabelled digit vector.
    labelled_centroids : list of (label, vector)

    Returns
    -------
    The label of the nearest centroid (Euclidean distance; the first one
    wins on ties).

    Raises
    ------
    ValueError
        If ``labelled_centroids`` is empty.
    """
    if len(labelled_centroids) == 0:
        raise ValueError("no labelled centroids to classify against")
    mindistance = float("inf")
    closest_centroid_label = None
    for label, centroid in labelled_centroids:
        distance = np.linalg.norm(centroid - digit)
        if distance < mindistance:
            mindistance = distance
            closest_centroid_label = label
    return closest_centroid_label

def get_error_rate(digits, labelled_centroids):
    """Return the fraction of labelled digits that are misclassified.

    Parameters
    ----------
    digits : list of (label, vector)
    labelled_centroids : list of (label, vector)

    Returns
    -------
    float in [0, 1]

    Raises
    ------
    ValueError
        If ``digits`` is empty (the rate is undefined).
    """
    if len(digits) == 0:
        raise ValueError("cannot compute an error rate on no digits")
    classified_incorrect = sum(
        1 for label, digit in digits
        if classify_digit(digit, labelled_centroids) != label)
    return classified_incorrect / float(len(digits))

k = 10

trained_clusters, trained_centroids = cluster(training, k, 20)
labelled_centroids = assign_labels_to_centroids(trained_clusters, trained_centroids)

# Classify the validation set once; the tables below only regroup the results.
inferred = [(label, classify_digit(digit, labelled_centroids)) for label, digit in validation]

for predicted in range(10):
    print('\033[1m' + 'Classification of the digit : ' + str(predicted) + ' Using ' + str(k) + ' clusters')
    # frequency[d]: how many true d's were classified as `predicted`.
    frequency = {x: 0 for x in range(10)}
    for label, inferred_label in inferred:
        if inferred_label == predicted:
            frequency[label] += 1
    print('\033[1m' + "Digit in the cluster        Frequency")
    for digit_value in range(len(frequency)):
        print('\033[0m' + str(digit_value) + "                  " + str(frequency[digit_value]))

print("Number of clusters is " + str(k))
for x in labelled_centroids:
    plt.figure()  # one figure per centroid image, otherwise they overwrite each other
    display_digit(x, title=x[0])

Classification of the digit : 0 Using 10 clusters
Digit in the cluster        Frequency
0                  1357
1                  0
2                  67
3                  204
4                  5
5                  376
6                  90
7                  5
8                  61
9                  13
Classification of the digit : 1 Using 10 clusters
Digit in the cluster        Frequency
0                  3
1                  1658
2                  166
3                  104
4                  57
5                  210
6                  111
7                  82
8                  148
9                  30
Classification of the digit : 2 Using 10 clusters
Digit in the cluster        Frequency
0                  5
1                  7
2                  1040
3                  27
4                  11
5                  1
6                  110
7                  7
8                  12
9                  4
Classification of the digit : 3 Using 10 clusters
Digit in the cluster        Frequency
0                  54
1                  1
2                  53
3                  916
4                  0
5                  396
6                  7
7                  1
8                  237
9                  23
Classification of the digit : 4 Using 10 clusters
Digit in the cluster        Frequency
0                  11
1                  2
2                  41
3                  26
4                  573
5                  37
6                  45
7                  167
8                  32
9                  409
Classification of the digit : 5 Using 10 clusters
Digit in the cluster        Frequency
0                  0
1                  0
2                  0
3                  0
4                  0
5                  0
6                  0
7                  0
8                  0
9                  0
Classification of the digit : 6 Using 10 clusters
Digit in the cluster        Frequency
0                  50
1                  2
2                  23
3                  12
4                  25
5                  27
6                  1124
7                  0
8                  17
9                  3
Classification of the digit : 7 Using 10 clusters
Digit in the cluster        Frequency
0                  3
1                  8
2                  16
3                  51
4                  796
5                  116
6                  0
7                  1325
8                  73
9                  1013
Classification of the digit : 8 Using 10 clusters
Digit in the cluster        Frequency
0                  13
1                  12
2                  56
3                  208
4                  1
5                  155
6                  3
7                  6
8                  852
9                  8
Classification of the digit : 9 Using 10 clusters
Digit in the cluster        Frequency
0                  0
1                  0
2                  0
3                  0
4                  0
5                  0
6                  0
7                  0
8                  0
9                  0
Number of clusters is 10

As we can see, the results are not very good, and they depend strongly on which digit we are trying to detect.

How to choose the number of clusters

First, we can plot the error rate as a function of the number of clusters. Then we can look at the gap statistic, a method developed by Tibshirani, Walther and Hastie in 2001.

max_iter = 20  # iteration cap passed to cluster()
# range() cannot be concatenated with a list in Python 3; build the list once.
cluster_counts = list(range(5, 20))
error_rates = {}
for k in cluster_counts:
    trained_clusters, trained_centroids = cluster(training, k, max_iter)
    labelled_centroids = assign_labels_to_centroids(trained_clusters, trained_centroids)
    # Measured on the training set itself, so this is an optimistic estimate.
    error_rates[k] = get_error_rate(training, labelled_centroids)
# Show the error rates
x_axis = sorted(error_rates)
y_axis = [error_rates[key] for key in x_axis]
plt.title("Error Rate by Number of Clusters")
plt.scatter(x_axis, y_axis)
plt.xlabel("Number of Clusters")
plt.ylabel("Error Rate")

import os
import pickle  # cPickle is Python 2 only; Python 3's pickle is already C-accelerated

os.makedirs('data', exist_ok=True)
with open('data/donnees_kmeans', 'wb') as fichier:
    my_pickler = pickle.Pickler(fichier)
    # NOTE(review): the pickler was created without writing anything;
    # saving error_rates is an assumption about the intended payload — confirm.
    my_pickler.dump(error_rates)

Now let's implement the method :

# Helper functions redefined for the gap statistic, so many clusters can be tried
def random_sample(X, k):
    """Return k distinct rows of X, chosen uniformly at random.

    Parameters
    ----------
    X : array-like of shape (n, d)
    k : int
        Number of rows (must not exceed n).

    Returns
    -------
    numpy.ndarray of shape (k, d)
    """
    X = np.asarray(X)
    return X[np.random.choice(X.shape[0], k, replace=False), :]

def pairwise_distances_argmin(X, y):
    """Give the index of the closest centroid of y to every point of X.

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    y : array of shape (k, d)
        Centroids.

    Returns
    -------
    numpy.ndarray of shape (n,), dtype intp (ties go to the lowest index).
    """
    X = np.asarray(X)
    y = np.asarray(y)
    indices = np.empty(len(X), dtype=np.intp)
    for i, point in enumerate(X):
        indices[i] = np.linalg.norm(y - point, axis=1).argmin()
    return indices

def kmeans_iteration(X, m):
    """One iteration of the K-means (Lloyd) method.

    Assigns every point of X to its nearest centroid in m, then moves each
    centroid to the mean of its points.

    Returns
    -------
    (centroids, clusters)
        Updated (k, d) centroids and the cluster index of every point.
    """
    clusters = pairwise_distances_argmin(X, m)
    centroids = np.empty(m.shape)
    for i in range(len(m)):
        members = X[clusters == i]
        # Keep the old centroid for an empty cluster instead of producing NaN.
        centroids[i] = members.mean(axis=0) if len(members) else m[i]
    return centroids, clusters

def kmeans(X, k, max_iter=50):
    """Repeat kmeans_iteration until convergence.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, d)
    k : int
        Number of clusters.
    max_iter : int, optional
        Give up once more than ``max_iter`` iterations have run (default 50).

    Returns
    -------
    (centroids, clusters)
    """
    m = random_sample(X, k)
    repetition = 0
    while True:
        repetition += 1
        new_m, clusters = kmeans_iteration(X, m)
        # Stop when the centroids no longer move, or the cap is exceeded.
        if np.isclose(m, new_m).all() or repetition > max_iter:
            break
        m = new_m
    return new_m, clusters

def Wk(X, centroids, clusters):
    """Energy of the clustering: within-cluster sum of squared distances.

    Parameters
    ----------
    X : array of shape (n, d)
    centroids : array of shape (k, d)
    clusters : integer array of shape (n,)
        Cluster index of each point.

    Returns
    -------
    float
        sum_i ||X[i] - centroids[clusters[i]]||^2
    """
    # Fancy indexing pairs every point with its own centroid in one step.
    diffs = np.asarray(X) - np.asarray(centroids)[clusters]
    return np.sum(diffs ** 2)

def monte_carlo(X, xmin, xmax, ymin, ymax, n=None):
    """Draw points uniformly from the rectangle [xmin, xmax] x [ymin, ymax].

    X only provides the default sample size, len(X); pass n to override it.
    Returns an (n, 2) array whose first column is x and second column is y.
    """
    size = len(X) if n is None else n
    # x coordinates are drawn before y coordinates.
    columns = [np.random.uniform(low, high, size=(size, 1))
               for low, high in ((xmin, xmax), (ymin, ymax))]
    return np.hstack(columns)

def gap_stats(X, K=8, B=10, n=None):
    """Compute the gap statistic (Tibshirani, Walther and Hastie, 2001).

    The reference distribution is uniform over the 2-D bounding box of X.

    Parameters
    ----------
    X : numpy.ndarray of shape (n_points, 2)
    K : int
        Largest number of clusters tried (k = 1..K).
    B : int
        Number of Monte Carlo reference samples.
    n : int, optional
        Size of each reference sample (defaults to len(X)).

    Returns
    -------
    ks, Wks, sample_Wks, gaps, sk
        ks = 1..K; Wks = log W_k of X; sample_Wks = mean log W_k of the
        reference samples; gaps = sample_Wks - Wks; sk = standard deviation
        of the reference log W_k scaled by sqrt(1 + 1/B).
    """
    (xmin, ymin), (xmax, ymax) = np.min(X, axis=0), np.max(X, axis=0)
    ks = np.arange(1, K + 1)
    # Generate B Monte Carlo samples (uniform) from the bounding box of X
    samples = [monte_carlo(X, xmin, xmax, ymin, ymax, n) for _ in range(B)]
    # Total energy of X for each k
    Wks = np.empty(K)
    # Mean total energy of samples for each k
    sample_Wks = np.empty(K)
    # Corrected standard deviation for each k
    sk = np.empty(K)
    for k in ks:
        Wks[k - 1] = np.log(Wk(X, *kmeans(X, k)))
        # Total energy for each sample
        current_Wks = np.empty(B)
        for i in range(B):
            sample = samples[i]
            current_Wks[i] = np.log(Wk(sample, *kmeans(sample, k)))
        sample_Wks[k - 1] = current_Wks.mean()
        sk[k - 1] = np.sqrt(((current_Wks - sample_Wks[k - 1]) ** 2).mean())
    # Correction factor. 1.0 keeps this a true division under Python 2 too
    # (1 / B would be 0 there and silently drop the correction).
    sk *= np.sqrt(1 + 1.0 / B)
    gaps = sample_Wks - Wks
    return ks, Wks, sample_Wks, gaps, sk

import matplotlib.ticker as ticker

def gaps_info(X, ks, Wks, sample_Wks, gaps, sk):
    """Plot the gap-statistic diagnostics on a 2x2 grid.

    Panels: top-left the 2-D data X; top-right log W_k of X against the
    mean log W_k of the reference samples; bottom-left Gap(k); bottom-right
    Gap(k) - Gap(k+1) + s_{k+1}, with the first positive bar highlighted.
    Plot layout adapted from http://signal-to-noise.xyz/kmeans.html
    """
    _, axes = plt.subplots(2, 2, figsize=(8, 7))
    axes[0, 0].plot(X[:, 0], X[:, 1],
                    linestyle='', marker='.',
                    color=cmap(0.2), markeredgecolor=cmap(0.25))
    line1, = axes[0, 1].plot(ks, Wks, marker='.', markersize=10)
    line2, = axes[0, 1].plot(ks, sample_Wks, marker='.', markersize=10)
    axes[0, 1].legend(
        (line1, line2),
        (r'$\log W_k$', r'$\frac{1}{B}\sum_{b = 1}^B\,\log W_{kb}^*$'))
    axes[1, 0].plot(ks, gaps, marker='.', markersize=10)
    gaps_diff = gaps[:-1] - gaps[1:] + sk[1:]
    barlist = axes[1, 1].bar(ks[:-1], gaps_diff,
                             width=0.5, align='center')
    # Highlight the smallest k whose bar is positive (the suggested k).
    # np.argmax returns 0 when nothing is positive, so guard that case.
    positive = gaps_diff > 0
    if positive.any():
        barlist[np.argmax(positive)].set_color(sns.xkcd_rgb['pale red'])
    axes[1, 1].set_ylabel(r'$\operatorname{Gap}(k) -'
                          r' \operatorname{Gap}(k + 1) + s_{k + 1}$')
    for (i, j) in ((0, 1), (1, 0), (1, 1)):
        # k is an integer: label the axis and avoid fractional ticks.
        axes[i, j].set_xlabel('Number of clusters $k$')
        axes[i, j].xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

from time import time
from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import rc
import seaborn as sns

rc('figure', figsize=(6, 4))
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap = plt.get_cmap('rainbow')

def plot_clusters(X, m, clusters):
    """Scatter each cluster of 2-D points in its own colour, centroid as a star.

    X: points; m: centroids, one row per cluster; clusters: cluster index of
    every point of X.
    """
    n_centroids = len(m)
    for idx in range(n_centroids):
        colour = cmap(idx / n_centroids)
        members = X[clusters == idx]
        plt.scatter(members[:, 0], members[:, 1], marker='.', color=colour)
        plt.scatter(m[idx, 0], m[idx, 1], marker='*', color=colour)
# The gap statistic is expensive: use only the first 25% of the digits.
data = np.array([vector for _, vector in digits[:int(len(digits) * 0.25)]])
data = scale(data)

# Even if the results aren't as good with PCA, we keep the 2-D projection for
# this first implementation: the monte_carlo/gap_stats helpers above sample a
# 2-D bounding box. The alternative is Monte Carlo sampling in all 784 dimensions.
reduced_data = PCA(n_components=2).fit_transform(data) # Reduced data, so it can be plotted in 2D

# Gap statistic for k = 1..15, then the 2x2 panel of diagnostic plots.
ks, Wks, sample_Wks, gaps, sk = gap_stats(reduced_data, K=15)
gaps_info(reduced_data, ks, Wks, sample_Wks, gaps, sk)

This figure shows how hard it is for the K-means algorithm to handle MNIST data, particularly the PCA-reduced data. We should try it on the full dataset.

# Goal is to get better results by not reducing the data
# Helper functions redefined for the full-dimensional digits
def random_sample(X, k):
    """Return k distinct rows of X, chosen uniformly at random, as an array.

    random.sample is not used: it returns a list (kmeans_iteration needs
    ``m.shape``) and rejects numpy arrays on Python >= 3.11.
    """
    X = np.asarray(X)
    return X[np.random.choice(len(X), k, replace=False)]
def pairwise_distances_argmin(X, y):
    """Give the index of the closest centroid of y to every point of X.

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    y : array of shape (k, d)
        Centroids.

    Returns
    -------
    numpy.ndarray of shape (n,), dtype intp (ties go to the lowest index).
    """
    X = np.asarray(X)
    y = np.asarray(y)
    indices = np.empty(len(X), dtype=np.intp)
    for i, point in enumerate(X):
        indices[i] = np.linalg.norm(y - point, axis=1).argmin()
    return indices

def kmeans_iteration(X, m):
    """One iteration of the K-means (Lloyd) method.

    Assigns every point of X to its nearest centroid in m, then moves each
    centroid to the mean of its points.

    Returns
    -------
    (centroids, clusters)
        Updated (k, d) centroids and the cluster index of every point.
    """
    clusters = pairwise_distances_argmin(X, m)
    centroids = np.empty(m.shape)
    for i in range(len(m)):
        members = X[clusters == i]
        # Keep the old centroid for an empty cluster instead of producing NaN.
        centroids[i] = members.mean(axis=0) if len(members) else m[i]
    return centroids, clusters

def kmeans(X, k, max_iter=50):
    """Repeat kmeans_iteration until convergence.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, d)
    k : int
        Number of clusters.
    max_iter : int, optional
        Give up once more than ``max_iter`` iterations have run (default 50).

    Returns
    -------
    (centroids, clusters)
    """
    m = random_sample(X, k)
    repetition = 0
    while True:
        repetition += 1
        new_m, clusters = kmeans_iteration(X, m)
        # Stop when the centroids no longer move, or the cap is exceeded.
        if np.isclose(m, new_m).all() or repetition > max_iter:
            break
        m = new_m
    return new_m, clusters

def monte_carlo(X, n=None):
    """Draw points uniformly from the bounding box of X, in every dimension.

    Parameters
    ----------
    X : array of shape (n_points, d)
    n : int, optional
        Sample size (defaults to len(X)).

    Returns
    -------
    numpy.ndarray of shape (n, d)
        Column j is uniform on [min(X[:, j]), max(X[:, j])].
    """
    X = np.asarray(X)
    xmin, xmax = np.min(X, axis=0), np.max(X, axis=0)
    n = len(X) if n is None else n
    # low/high broadcast per column, so every coordinate gets its own range.
    # (The previous per-column loop overwrote column 0 and returned d-1 columns.)
    return np.random.uniform(xmin, xmax, size=(n, X.shape[1]))

def Wk(X, centroids, clusters):
    """Energy of the clustering: within-cluster sum of squared distances.

    Parameters
    ----------
    X : array of shape (n, d)
    centroids : array of shape (k, d)
    clusters : integer array of shape (n,)
        Cluster index of each point.

    Returns
    -------
    float
        sum_i ||X[i] - centroids[clusters[i]]||^2
    """
    # Fancy indexing pairs every point with its own centroid in one step.
    diffs = np.asarray(X) - np.asarray(centroids)[clusters]
    return np.sum(diffs ** 2)

def gap_stats(X, K=8, B=10, n=None):
    """Compute the gap statistic (Tibshirani, Walther and Hastie, 2001).

    The reference samples come from monte_carlo(X, n), drawn over the
    bounding box of X.

    Parameters
    ----------
    X : numpy.ndarray of shape (n_points, d)
    K : int
        Largest number of clusters tried (k = 1..K).
    B : int
        Number of Monte Carlo reference samples.
    n : int, optional
        Size of each reference sample (defaults to len(X)).

    Returns
    -------
    ks, Wks, sample_Wks, gaps, sk
        ks = 1..K; Wks = log W_k of X; sample_Wks = mean log W_k of the
        reference samples; gaps = sample_Wks - Wks; sk = standard deviation
        of the reference log W_k scaled by sqrt(1 + 1/B).
    """
    ks = np.arange(1, K + 1)
    # Generate B Monte Carlo samples (uniform) from the bounding box of X
    samples = [monte_carlo(X, n) for _ in range(B)]
    # Total energy of X for each k
    Wks = np.empty(K)
    # Mean total energy of samples for each k
    sample_Wks = np.empty(K)
    # Corrected standard deviation for each k
    sk = np.empty(K)
    for k in ks:
        Wks[k - 1] = np.log(Wk(X, *kmeans(X, k)))
        # Total energy for each sample
        current_Wks = np.empty(B)
        for i in range(B):
            sample = samples[i]
            current_Wks[i] = np.log(Wk(sample, *kmeans(sample, k)))
        sample_Wks[k - 1] = current_Wks.mean()
        sk[k - 1] = np.sqrt(((current_Wks - sample_Wks[k - 1]) ** 2).mean())
    # Correction factor. 1.0 keeps this a true division under Python 2 too
    # (1 / B would be 0 there and silently drop the correction).
    sk *= np.sqrt(1 + 1.0 / B)
    gaps = sample_Wks - Wks
    return ks, Wks, sample_Wks, gaps, sk

import matplotlib.ticker as ticker

def gaps_info(X, ks, Wks, sample_Wks, gaps, sk):
    """Plot the gap-statistic diagnostics (high-dimensional variant).

    X is accepted for signature compatibility but not drawn: it has more
    than two dimensions here, so the top-left panel is left empty.
    Top-right: log W_k of X against the mean log W_k of the reference
    samples; bottom-left: Gap(k); bottom-right: Gap(k) - Gap(k+1) + s_{k+1},
    with the first positive bar highlighted.
    Plot layout adapted from http://signal-to-noise.xyz/kmeans.html
    """
    _, axes = plt.subplots(2, 2, figsize=(8, 7))
    line1, = axes[0, 1].plot(ks, Wks, marker='.', markersize=10)
    line2, = axes[0, 1].plot(ks, sample_Wks, marker='.', markersize=10)
    axes[0, 1].legend(
        (line1, line2),
        (r'$\log W_k$', r'$\frac{1}{B}\sum_{b = 1}^B\,\log W_{kb}^*$'))
    axes[1, 0].plot(ks, gaps, marker='.', markersize=10)
    gaps_diff = gaps[:-1] - gaps[1:] + sk[1:]
    barlist = axes[1, 1].bar(ks[:-1], gaps_diff,
                             width=0.5, align='center')
    # Highlight the smallest k whose bar is positive (the suggested k).
    # np.argmax returns 0 when nothing is positive, so guard that case.
    positive = gaps_diff > 0
    if positive.any():
        barlist[np.argmax(positive)].set_color(sns.xkcd_rgb['pale red'])
    axes[1, 1].set_ylabel(r'$\operatorname{Gap}(k) -'
                          r' \operatorname{Gap}(k + 1) + s_{k + 1}$')
    for (i, j) in ((0, 1), (1, 0), (1, 1)):
        # k is an integer: label the axis and avoid fractional ticks.
        axes[i, j].set_xlabel('Number of clusters $k$')
        axes[i, j].xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

from time import time
from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import rc
import seaborn as sns

rc('figure', figsize=(6, 4))
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap = plt.get_cmap('rainbow')

def plot_clusters(X, m, clusters):
    """Draw every cluster of 2-D points in its own colour, centroid marked '*'.

    X: points; m: centroids (row j belongs to cluster j); clusters: cluster
    index of every point of X.
    """
    total = len(m)
    for j in range(total):
        shade = cmap(j / total)
        pts = X[clusters == j]
        plt.scatter(pts[:, 0], pts[:, 1], marker='.', color=shade)
        plt.scatter(m[j, 0], m[j, 1], marker='*', color=shade)
# The gap statistic is expensive: use only the first 25% of the digits.
data = np.array([vector for _, vector in digits[:int(len(digits) * 0.25)]])
#data = scale(data)

# Here the PCA step is skipped: gap_stats runs on the full pixel vectors, and
# the monte_carlo helper above is meant to draw reference samples over the
# bounding box of every coordinate of `data`.
# (PCA alternative kept for reference:)
#reduced_data = PCA(n_components=2).fit_transform(data) # Reduced data, to plot it in 2D

# Gap statistic for k = 1..15 on the full-dimensional data, then the diagnostic plots.
ks, Wks, sample_Wks, gaps, sk = gap_stats(data, K=15)
gaps_info(data, ks, Wks, sample_Wks, gaps, sk)

