In [3]:
    
%matplotlib inline
import math, numpy as np, tensorflow as tf, matplotlib.pyplot as plt, operator
from importlib import reload
    
In [4]:
    
import kmeans; reload(kmeans)
from kmeans import Kmeans
    
In [5]:
    
# Synthetic-data settings: how many Gaussian clusters to generate and how many
# points to draw for each of them.
n_clusters, n_samples = 6, 25000
    
In [6]:
    
# Local, fixed-seed generator: the synthetic dataset is identical on every
# Restart & Run All, and the global NumPy RNG state is left untouched.
rng = np.random.RandomState(42)
# True cluster centres, drawn uniformly from [-35, 35) along each of the 2 axes.
centroids = rng.uniform(-35, 35, (n_clusters, 2))
# n_samples points per centre from a 2-D Gaussian with covariance diag(5, 5).
slices = [rng.multivariate_normal(centroids[i], np.diag([5., 5.]), n_samples)
          for i in range(n_clusters)]
# Stack the per-cluster samples into one float32 array, cluster after cluster.
data = np.concatenate(slices).astype(np.float32)
    
In [7]:
    
# Visual check of the generated data with the true centroids.
# NOTE(review): plot_data lives in the local kmeans module; its exact rendering is not visible here.
kmeans.plot_data(centroids, data, n_samples)
    
    
In [8]:
    
# Instantiate the Kmeans class from the local kmeans module with the full
# dataset and the requested number of clusters.
# NOTE(review): constructor semantics are inferred from the argument names only — confirm in kmeans.py.
k = Kmeans(data, n_clusters)
    
In [9]:
    
with tf.Session().as_default():
    %time new_centroids = k.run()
    
    
In [10]:
    
# Same plot as before, but with the centroids returned by k.run() instead of
# the true ones used to generate the data.
kmeans.plot_data(new_centroids, data, n_samples)
    
    
In [11]:
    
def gaussian(d, bw):
    """Evaluate the zero-mean normal density with standard deviation `bw` at `d`.

    d  -- torch tensor; the density is evaluated element-wise.
    bw -- bandwidth (scalar).
    Returns a tensor with the same shape as `d`.
    """
    norm_const = bw * math.sqrt(2 * math.pi)
    scaled = d / bw
    return torch.exp(-0.5 * scaled ** 2) / norm_const
    
In [12]:
    
def dist_b(a,b):
    """Batched Euclidean distance between every element of `b` and every element of `a`.

    `a` gets a new leading axis and `b` a new axis at position 1, so the
    difference is indexed as [row of b, row of a, coordinate]; the squared
    differences are then summed over axis 2 and square-rooted.

    NOTE(review): assumes `a` and `b` are 2-D (points x coordinates) and that
    `sub` (from the torch_utils star import) is a broadcasting subtraction — confirm.
    """
    diff = sub(a.unsqueeze(0), b.unsqueeze(1))
    sq_dist = (diff ** 2).sum(2)
    return torch.sqrt(sq_dist)
    
In [13]:
    
def sum_sqz(a,axis):
    """Sum tensor `a` over `axis` and return the result with that axis removed.

    a    -- torch tensor.
    axis -- index of the dimension to reduce.

    Whether `sum(axis)` keeps the reduced dimension as size 1 depends on the
    PyTorch version. Squeezing unconditionally is only correct when it is kept:
    if the sum has already dropped it, `squeeze(axis)` would hit a different
    dimension (silently removing it if it has size 1) or raise when `axis` was
    the last one. So the squeeze is applied only when the rank is unchanged.
    """
    summed = a.sum(axis)
    if summed.dim() == a.dim():
        # Reduced axis still present as a size-1 dimension: drop it.
        summed = summed.squeeze(axis)
    return summed
    
In [14]:
    
import torch
from torch_utils import * 
from pytorch_lshash import PyTorchLSHash
def meanshift_lsh(data, bs=500, n_iter=5, bw=2):
    """Batched mean shift on the GPU, taking candidate neighbours from an LSH index.

    data   -- array-like of points; it is copied, so the caller's array is not modified.
    bs     -- number of points updated per batch; also passed as the last
              argument of `pylsh.query`.
    n_iter -- number of full passes over the data (default 5, the value that
              used to be hard-coded).
    bw     -- bandwidth of the Gaussian kernel (default 2, the value that used
              to be hard-coded).

    Returns the CUDA float tensor of shifted points.
    """
    n = len(data)
    X = torch.FloatTensor(np.copy(data)).cuda()

    for _ in range(n_iter):
        # A fresh index is built from the current X at the start of every pass.
        # NOTE(review): the (6, 2) arguments are presumably hash size and input
        # dimensionality — confirm against pytorch_lshash.
        pylsh = PyTorchLSHash(6, 2)
        pylsh.index(X)
        for i in range(0, n, bs):
            s = slice(i, min(n, i + bs))
            candidates = pylsh.query(s, X, bs)
            weight = gaussian(dist_b(candidates, X[s]), bw)
            num = sum_sqz(mul(weight, candidates), 1)
            # Kernel-weighted mean of the candidates, written back in place:
            # later batches of the same pass see these already-shifted points.
            X[s] = div(num, sum_sqz(weight, 1))
    return X
    
In [15]:
    
# Time one run of the LSH-based mean shift on the full dataset; X is the
# tensor of shifted points it returns.
%time X=meanshift_lsh(data)
    
    
In [16]:
    
# Copy the shifted points back to host memory as a NumPy array and plot them
# together with the k-means centroids.
kmeans.plot_data(new_centroids, X.cpu().numpy(), n_samples)
    
    
In [17]:
    
def meanshift_random(data, bs=500, n_iter=5, bw=2):
    """Batched mean shift on the GPU, using a fresh random subset of points as
    the candidate neighbours for every batch.

    data   -- array-like of points; it is copied, so the caller's array is not modified.
    bs     -- number of points updated per batch, and also the number of
              random candidates drawn for each batch.
    n_iter -- number of full passes over the data (default 5, the value that
              used to be hard-coded).
    bw     -- bandwidth of the Gaussian kernel (default 2, the value that used
              to be hard-coded).

    Returns the CUDA float tensor of shifted points.
    """
    n = len(data)
    X = torch.FloatTensor(np.copy(data)).cuda()

    for _ in range(n_iter):
        for i in range(0, n, bs):
            s = slice(i, min(n, i + bs))
            # First bs entries of a random permutation of all row indices,
            # i.e. bs distinct rows of X chosen at random.
            picked = torch.randperm(len(X))[:bs].long().cuda()
            candidates = X[picked]
            weight = gaussian(dist_b(candidates, X[s]), bw)
            num = sum_sqz(mul(weight, candidates), 1)
            # Kernel-weighted mean of the candidates, written back in place:
            # later batches of the same pass see these already-shifted points.
            X[s] = div(num, sum_sqz(weight, 1))
    return X
    
In [18]:
    
# Time one run of the random-candidate mean shift on the full dataset; X is
# the tensor of shifted points it returns.
%time X=meanshift_random(data)
    
    
In [19]:
    
# Copy the shifted points back to host memory as a NumPy array and plot them
# together with the k-means centroids.
kmeans.plot_data(new_centroids, X.cpu().numpy(), n_samples)
    
    
In [22]:
    
def meanshift_lsh_record_conv(data, bs=500, n_iter=10, bw=2):
    """LSH-candidate mean shift that also records how far the points move in each pass.

    data   -- array-like of points; it is copied, so the caller's array is not modified.
    bs     -- number of points updated per batch; also passed as the last
              argument of `pylsh.query`.
    n_iter -- number of full passes (default 10, the value that used to be hard-coded).
    bw     -- bandwidth of the Gaussian kernel (default 2, previously hard-coded).

    Returns `(X, conv)`: the CUDA tensor of shifted points and a list with one
    Python float per pass, the sum over all points of the Euclidean distance
    between a point's position before and after that pass.
    """
    n = len(data)
    X = torch.FloatTensor(np.copy(data)).cuda()
    conv = []

    for _ in range(n_iter):
        # A fresh index is built from the current X at the start of every pass.
        # NOTE(review): the (6, 2) arguments are presumably hash size and input
        # dimensionality — confirm against pytorch_lshash.
        pylsh = PyTorchLSHash(6, 2)
        pylsh.index(X)
        # New positions go into a separate buffer, so every batch of this pass
        # reads the positions from the end of the previous pass.
        Y = torch.zeros(X.size()).cuda()
        for i in range(0, n, bs):
            s = slice(i, min(n, i + bs))
            candidates = pylsh.query(s, X, bs)
            weight = gaussian(dist_b(candidates, X[s]), bw)
            num = sum_sqz(mul(weight, candidates), 1)
            Y[s] = div(num, sum_sqz(weight, 1))
        # float(): depending on the PyTorch version, a full reduction yields
        # either a Python number or a 0-dim (here CUDA) tensor, and matplotlib
        # cannot plot the latter. float() accepts both.
        conv.append(float(torch.sqrt(((Y - X) ** 2).sum(1)).sum()))
        X = Y
    return X, conv
def meanshift_random_record_conv(data, bs=500, n_iter=5, bw=2):
    """Random-candidate mean shift that also records how far the points move in each pass.

    data   -- array-like of points; it is copied, so the caller's array is not modified.
    bs     -- number of points updated per batch, and also the number of
              random candidates drawn for each batch.
    n_iter -- number of full passes (default 5, the value that used to be hard-coded).
    bw     -- bandwidth of the Gaussian kernel (default 2, previously hard-coded).

    Returns `(X, conv)`: the CUDA tensor of shifted points and a list with one
    Python float per pass, the sum over all points of the Euclidean distance
    between a point's position before and after that pass.
    """
    n = len(data)
    X = torch.FloatTensor(np.copy(data)).cuda()
    conv = []

    for _ in range(n_iter):
        # New positions go into a separate buffer, so every batch of this pass
        # reads the positions from the end of the previous pass.
        Y = torch.zeros(X.size()).cuda()
        for i in range(0, n, bs):
            s = slice(i, min(n, i + bs))
            # First bs entries of a random permutation of all row indices,
            # i.e. bs distinct rows of X chosen at random.
            picked = torch.randperm(len(X))[:bs].long().cuda()
            candidates = X[picked]
            weight = gaussian(dist_b(candidates, X[s]), bw)
            num = sum_sqz(mul(weight, candidates), 1)
            Y[s] = div(num, sum_sqz(weight, 1))
        # float(): depending on the PyTorch version, a full reduction yields
        # either a Python number or a 0-dim (here CUDA) tensor, and matplotlib
        # cannot plot the latter. float() accepts both.
        conv.append(float(torch.sqrt(((Y - X) ** 2).sum(1)).sum()))
        X = Y
    return X, conv
def plot_convergence():
    """Run both convergence-recording mean-shift variants on the global `data`
    and draw their per-iteration convergence curves on a single figure.

    Takes no arguments and returns nothing; the figure is left for the
    notebook backend to display. The final point tensors returned by the two
    runs are discarded, only the convergence lists are plotted.
    """
    _, conv_lsh = meanshift_lsh_record_conv(data)
    _, conv_rand = meanshift_random_record_conv(data)

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)
    # Draw through the Axes object throughout instead of mixing it with the
    # pyplot state machine.
    ax.plot(conv_lsh, label='LSH approximation')
    ax.plot(conv_rand, 'g--', label='random draw')
    ax.set_title('Convergence Rate', size=20)
    ax.set_xlabel('Iteration', size=16)
    ax.legend(loc="best")
    
# Runs both mean-shift variants again from scratch on `data`, then plots the result.
plot_convergence()
    
    
In [ ]: