In [3]:
%matplotlib inline
import math, numpy as np, tensorflow as tf, matplotlib.pyplot as plt, operator
from importlib import reload

In [4]:
import kmeans; reload(kmeans)
from kmeans import Kmeans

In [5]:
n_clusters=6      # number of Gaussian blobs to generate
n_samples =25000  # points drawn per cluster

In [6]:
# Random cluster centres in [-35, 35)^2, then `n_samples` points around each
# from an isotropic Gaussian (variance 5 per axis); float32 for the GPU code below.
# NOTE(review): no np.random.seed is set, so every run generates different data.
centroids = np.random.uniform(-35, 35, (n_clusters, 2))
slices = [np.random.multivariate_normal(centroids[i], np.diag([5., 5.]), n_samples)
           for i in range(n_clusters)]
data = np.concatenate(slices).astype(np.float32)

In [7]:
# Visualise the generated data against the true centroids (plot_data lives in kmeans.py).
kmeans.plot_data(centroids, data, n_samples)


Kmeans


In [8]:
# Build the k-means solver over the generated data (implementation in kmeans.py).
k = Kmeans(data, n_clusters)

In [9]:
# TF1-style session API; presumably Kmeans.run() needs a default session — TODO confirm.
with tf.Session().as_default():
    %time new_centroids = k.run()


CPU times: user 680 ms, sys: 108 ms, total: 788 ms
Wall time: 802 ms

In [10]:
# Same scatter as above, now overlaying the centroids k-means found.
kmeans.plot_data(new_centroids, data, n_samples)


Mean shift with LSH


In [11]:
def gaussian(d, bw):
    """Gaussian kernel weights: normal pdf with bandwidth `bw` evaluated at distances `d`."""
    z = d / bw
    return torch.exp(-0.5 * z ** 2) / (bw * math.sqrt(2 * math.pi))

In [12]:
def dist_b(a, b):
    """All-pairs Euclidean distances: result[i, j] = ||b[i] - a[j]||.

    `sub` (torch_utils helper, presumably broadcasting subtraction) pairs the
    (1, len(a), d) view of `a` with the (len(b), 1, d) view of `b`.
    """
    diff = sub(a.unsqueeze(0), b.unsqueeze(1))
    return torch.sqrt((diff ** 2).sum(2))

In [13]:
def sum_sqz(a, axis):
    """Sum `a` over `axis`, then squeeze that axis out of the result."""
    summed = a.sum(axis)
    return summed.squeeze(axis)

In [14]:
import torch
from torch_utils import * 
from pytorch_lshash import PyTorchLSHash

def meanshift_lsh(data, bs=500):
    """GPU mean shift where each batch's kernel sum runs over LSH-bucket
    candidates instead of the full dataset.

    data: array-like of points; bs: batch size. Returns the shifted points
    as a CUDA FloatTensor after 5 iterations.
    """
    X = torch.FloatTensor(np.copy(data)).cuda()
    n = len(data)

    for _ in range(5):
        # Re-hash every iteration, since the points have moved.
        pylsh = PyTorchLSHash(6, 2)
        pylsh.index(X)
        for start in range(0, n, bs):
            batch = slice(start, min(n, start + bs))
            # Candidate neighbours for this batch — contract defined in pytorch_lshash.
            candidates = pylsh.query(batch, X, bs)
            weight = gaussian(dist_b(candidates, X[batch]), 2)
            num = sum_sqz(mul(weight, candidates), 1)
            X[batch] = div(num, sum_sqz(weight, 1))
    return X

In [15]:
# Time one full LSH mean-shift run on the GPU.
%time X=meanshift_lsh(data)


CPU times: user 6.6 s, sys: 944 ms, total: 7.55 s
Wall time: 7.54 s

In [16]:
# Plot the shifted point cloud; NOTE(review): the overlaid centroids are the
# k-means ones from earlier, re-used here for reference.
kmeans.plot_data(new_centroids, X.cpu().numpy(), n_samples)


Mean shift with random draw


In [17]:
def meanshift_random(data, bs=500):
    """Approximate mean shift: each batch is shifted toward the weighted mean
    of `bs` randomly drawn points rather than of the full dataset.

    data: array-like of points; bs: batch size and sample size. Returns the
    shifted points as a CUDA FloatTensor after 5 iterations.
    """
    X = torch.FloatTensor(np.copy(data)).cuda()
    n = len(data)

    for _ in range(5):
        for start in range(0, n, bs):
            batch = slice(start, min(n, start + bs))
            # A fresh random sample of bs points stands in for the whole dataset.
            sample_idx = torch.randperm(len(X))[:bs].long().cuda()
            candidates = X[sample_idx]
            weight = gaussian(dist_b(candidates, X[batch]), 2)
            num = sum_sqz(mul(weight, candidates), 1)
            X[batch] = div(num, sum_sqz(weight, 1))
    return X

In [18]:
# Time one full random-draw mean-shift run for comparison with the LSH version.
%time X=meanshift_random(data)


CPU times: user 3.46 s, sys: 8 ms, total: 3.46 s
Wall time: 3.46 s

In [19]:
# Plot the random-draw mean-shift result; centroids shown are the earlier k-means ones.
kmeans.plot_data(new_centroids, X.cpu().numpy(), n_samples)


Comparison of convergence


In [22]:
def meanshift_lsh_record_conv(data, bs=500, n_iter=10):
    """LSH-approximated mean shift that also records convergence.

    data: array-like of points; bs: batch size; n_iter: number of
    mean-shift iterations (default 10 preserves the original hard-coded
    count).

    Returns (X, conv): the shifted points (CUDA FloatTensor) and a list
    where conv[k] is the total L2 distance all points moved on
    iteration k.
    """
    n = len(data)
    X = torch.FloatTensor(np.copy(data)).cuda()
    conv = []

    for it in range(n_iter):
        # Re-hash every iteration, since the points have moved.
        pylsh = PyTorchLSHash(6, 2)
        pylsh.index(X)
        # Write shifts into a fresh Y so movement is measured against this
        # iteration's unmodified X.
        Y = torch.zeros(X.size()).cuda()
        for i in range(0, n, bs):
            s = slice(i, min(n, i + bs))
            candidates = pylsh.query(s, X, bs)
            weight = gaussian(dist_b(candidates, X[s]), 2)
            num = sum_sqz(mul(weight, candidates), 1)
            Y[s] = div(num, sum_sqz(weight, 1))
        conv.append(torch.sqrt(((Y - X) ** 2).sum(1)).sum())
        X = Y
    return X, conv


def meanshift_random_record_conv(data, bs=500, n_iter=5):
    """Random-draw mean shift that also records convergence.

    data: array-like of points; bs: batch size and random-sample size;
    n_iter: number of mean-shift iterations (default 5 preserves the
    original hard-coded count).

    Returns (X, conv): the shifted points (CUDA FloatTensor) and a list
    where conv[k] is the total L2 distance all points moved on
    iteration k.
    """
    n = len(data)
    X = torch.FloatTensor(np.copy(data)).cuda()
    conv = []

    for it in range(n_iter):
        # Write shifts into a fresh Y so movement is measured against this
        # iteration's unmodified X.
        Y = torch.zeros(X.size()).cuda()
        for i in range(0, n, bs):
            s = slice(i, min(n, i + bs))
            # [:bs] replaces the equivalent but unidiomatic [slice(0, bs)].
            candidates = X[torch.randperm(len(X))[:bs].long().cuda()]
            weight = gaussian(dist_b(candidates, X[s]), 2)
            num = sum_sqz(mul(weight, candidates), 1)
            Y[s] = div(num, sum_sqz(weight, 1))
        conv.append(torch.sqrt(((Y - X) ** 2).sum(1)).sum())
        X = Y
    return X, conv


def plot_convergence():
    """Run both approximate mean-shift variants on the notebook's `data`
    and overlay their per-iteration total shift to compare convergence.

    NOTE: the LSH variant records 10 iterations vs. 5 for the random
    draw, so the two curves have different lengths.
    """
    # Only the convergence traces are plotted; the shifted points are discarded.
    _, conv1 = meanshift_lsh_record_conv(data)
    _, conv2 = meanshift_random_record_conv(data)
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)
    plt.plot(conv1, label='LSH approximation')
    plt.plot(conv2, 'g--', label='random draw')
    ax.set_title('Convergence Rate', size=20)
    ax.set_xlabel('Iteration', size=16)
    ax.set_ylabel('Total shift (L2)', size=16)
    plt.legend(loc="best")

plot_convergence()



In [ ]: