# In [0]:
from sklearn.datasets.samples_generator import make_moons
from matplotlib import pyplot as plt
from pandas import DataFrame
import numpy as np
import torch
import time
import torch.nn as nn
import torch.nn.functional as nnf
import torchvision.transforms.functional as TF
from torch.autograd import Variable
# Draw function
def drawXW(Xc, Wc):
    """Open a new figure showing data points and centroids.

    Args:
        Xc: Data points; columns 0 and 1 are plotted as small cyan dots.
        Wc: Centroids; columns 0 and 1 are plotted as blue markers.

    The view is fixed to the square [-2, 2] x [-2, 2] with equal aspect ratio.
    """
    plt.figure()
    # Each layer is (points, scatter keyword arguments): data first, then centroids.
    for pts, style in ((Xc, {'c': 'c', 's': 1}), (Wc, {'c': 'b'})):
        plt.scatter(pts[:, 0], pts[:, 1], **style)
    limits = (-2, 2)
    plt.xlim(*limits)
    plt.ylim(*limits)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.draw()
class KMeans(nn.Module):
    """Batch (Lloyd-style) k-means clustering written as an ``nn.Module``.

    Usage: call the module on the data to get sample-to-centroid squared
    distances, then call :meth:`update` to move every centroid to the mean of
    the samples assigned to it.

    Args:
        k: Number of centroids.
        W_ini: Optional initial centroids, one per row (presumably a
            ``(k, n_features)`` tensor on the same device/dtype as the data —
            confirm at call sites). When ``None``, ``k`` random rows of the
            first batch passed to :meth:`forward` are used.
        dropout: Stored on the instance but not used by any method.
    """

    def __init__(self, k, W_ini=None, dropout=0):
        super(KMeans, self).__init__()  # Inherited from the parent class nn.Module
        self.k = k
        self.isInit = False
        self.W_ini = W_ini
        self.dropout = dropout  # not used anywhere yet
        self.W = torch.tensor([])  # centroids, one per row
        self.D = torch.tensor([])  # squared distances from the last forward pass
        self.X = None  # data from the last forward pass, consumed by update()
        self.loss = []  # per update(): sum of squared distances to the winning centroid

    def forward(self, X):
        """Return squared Euclidean distances between samples and centroids.

        On the first call, this also initializes the centroids. ``X`` and the
        distance matrix are cached for the next :meth:`update`.

        Args:
            X: 2-D tensor of samples, one per row.

        Returns:
            Tensor of shape ``(len(X), number_of_centroids)``.
        """
        if not self.isInit:  # First initialization
            self.isInit = True
            if self.W_ini is not None:
                W = self.W_ini
            else:
                # Build the permutation on X's own device instead of forcing
                # .cuda(), so the module also runs on CPU-only machines.
                idx = torch.randperm(len(X), device=X.device)
                W = X[idx[:self.k]]
            self.W = W
        D = self.dist2(X, self.W)
        self.X = X  # update() must use this data, not a global variable
        self.D = D  # Save distances for updating later
        return D

    def update(self):
        """Move each centroid to the mean of the samples that chose it.

        Uses the data and distances cached by the last :meth:`forward` call.
        A centroid that won no samples keeps its previous position instead of
        collapsing to the origin. Appends the pre-update loss (sum of each
        sample's squared distance to its nearest centroid) to ``self.loss``.

        Raises:
            RuntimeError: if :meth:`forward` has not been called yet.
        """
        if self.X is None:
            raise RuntimeError("KMeans.update() called before forward()")
        X, D, W = self.X, self.D, self.W
        # Hard assignment with exactly one winner per sample: min() returns a
        # single index even on ties, so no sample is counted in two clusters.
        Dmin, winner = D.min(dim=1)
        # Match X's dtype so the matmul below works for float32 and float64.
        S = nnf.one_hot(winner, num_classes=W.shape[0]).to(X.dtype)  # (n, k)
        counts = S.sum(dim=0).unsqueeze(1)  # samples per centroid, (k, 1)
        # clamp_min(1) only avoids 0/0; empty rows are replaced just below.
        means = torch.mm(S.t(), X) / counts.clamp_min(1)
        self.W = torch.where(counts > 0, means, W)  # empty clusters keep old centroid
        self.loss.append(Dmin.sum().item())

    def dist2(self, X, Y):
        """Return squared Euclidean distances between every row of X and every row of Y.

        Computed as ||x||^2 + ||y||^2 - 2 x.y. The result is clamped at zero
        because floating-point cancellation can otherwise yield tiny
        negative values.

        Returns:
            Tensor of shape ``(len(X), len(Y))``.
        """
        x_norm = (X**2).sum(1).view(-1, 1)
        y_norm = (Y**2).sum(1).view(1, -1)
        dist = x_norm + y_norm - 2.0 * X.mm(Y.t())
        return dist.clamp_min(0)
# Basic params
n_samples = 10000
n_features = 2
n_units = 20
n_centers = 3
n_epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# generate dataset
#X, y = make_blobs(n_samples, n_features, n_centers )
X, y = make_moons(n_samples, noise=0.1)
# Use the selected device rather than forcing .cuda(), so the script also
# runs on machines without a GPU.
X = torch.from_numpy(X).to(device)
# Kmeans
start = time.time()
km = KMeans(n_units)
for e in range(n_epochs):
    activation = km(X)  # first call also initializes km.W
    if e == 0:
        drawXW(X.cpu(), km.W.cpu())  # initial centroids, before the first update
    km.update()
# CUDA kernels are launched asynchronously; wait for them to finish so the
# elapsed time covers the actual computation.
if device.type == "cuda":
    torch.cuda.synchronize()
end = time.time()
print('time =', end - start)
drawXW(X.cpu(), km.W.cpu())  # final centroids
# In [0]:
import torch.nn as nn
pdist = nn.PairwiseDistance(p=2)  # retained so existing references to `pdist` keep working
# Euclidean distance from every sample to every centroid. The centroids live
# in km.W (there is no module-level `W`). nn.PairwiseDistance only compares
# corresponding rows, so it cannot produce an all-pairs matrix for 10000
# samples vs 20 centroids; torch.cdist does exactly that.
DD = torch.cdist(X, km.W, p=2)  # shape (len(X), len(km.W))
# In [0]: