In [12]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt

In [13]:
from sklearn.datasets import make_moons, make_circles

# Toy two-class problem: two interleaving half-moons with Gaussian noise.
# A fixed seed makes the data (and every downstream plot) reproducible on
# Restart & Run All.
SEED = 0
X, y = make_moons(n_samples=500, noise=0.3, random_state=SEED)

# Represent the coordinates in several convenient ways:
#   I, J -- the two feature columns as separate 1-D arrays (handy for plotting)
#   IJ   -- the (n_samples, 2) point matrix (handy for distance computations)
I, J = X.T
IJ = X[:, :2]

plt.gca().set_aspect('equal')
plt.title('make_moons training data')
plt.scatter(I, J, c=y, cmap='cool');


Out[13]:
<matplotlib.collections.PathCollection at 0x7f3432f0eb50>

In [14]:
from scipy.spatial.distance import cdist

def knn(query_pts, pts, k=3):
    """Brute-force k-nearest-neighbour search.

    Parameters
    ----------
    query_pts : array_like, shape (n_queries, n_features)
        Points to find neighbours for.
    pts : array_like, shape (n_points, n_features)
        Reference points to search among.
    k : int, default 3
        Number of neighbours per query. Values larger than ``n_points`` are
        clamped to ``n_points`` (every point is returned).

    Returns
    -------
    knn_dist : ndarray, shape (n_queries, k)
        Distances (``cdist`` default metric, Euclidean) to the neighbours,
        nearest first.
    knn_ix : ndarray of int, shape (n_queries, k)
        Row indices into ``pts`` of the neighbours, nearest first.

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    query_pts = np.asarray(query_pts)
    pts = np.asarray(pts)
    # Asking for more neighbours than there are points would otherwise fail
    # with an opaque index-shape mismatch; return every point instead.
    k = min(k, pts.shape[0])

    # Full (n_queries, n_points) distance matrix -- fine for small problems.
    D = cdist(query_pts, pts)
    # Stable sort: equidistant points are ordered by index, so results are
    # deterministic.
    knn_ix = np.argsort(D, axis=1, kind='stable')[:, :k]
    # Gather each row's distances at that row's neighbour indices.
    knn_dist = np.take_along_axis(D, knn_ix, axis=1)
    return knn_dist, knn_ix

In [15]:
from scipy import stats

# Build a regular grid over the bounding box of the data; we predict a class
# for each grid point and render the predictions as an image.
resolution = 100
K_NEIGHBORS = 7  # odd, so a two-class vote can never tie

imin, imax = np.min(I), np.max(I)
jmin, jmax = np.min(J), np.max(J)
m = np.linspace(imin, imax, resolution)
n = np.linspace(jmin, jmax, resolution)
M, N = np.meshgrid(m, n)
# (resolution**2, 2) grid coordinates in row-major order, so the predictions
# reshape straight back into a (resolution, resolution) image.
MN = np.dstack((M, N)).reshape(-1, 2)

# k nearest training points for each grid point (distances are not needed).
_, nabe = knn(MN, IJ, k=K_NEIGHBORS)

# Class label of each neighbour: shape (resolution**2, K_NEIGHBORS).
k_classes = y[nabe]

# Majority vote. Counting votes per class matches scipy.stats.mode (ties go
# to the smallest label, since np.unique is sorted and argmax picks the
# first maximum) without depending on SciPy's version-specific `keepdims`
# behaviour or its FutureWarning.
classes = np.unique(y)
votes = (k_classes[:, :, None] == classes).sum(axis=1)
pred = classes[votes.argmax(axis=1)]

pred_img = pred.reshape((resolution, resolution))

# linspace places the grid points on the pixel *centres*, so the image must
# extend half a grid step beyond the data range on every side; using the raw
# min/max would shift the decision regions by half a pixel.
di = (imax - imin) / (resolution - 1)
dj = (jmax - jmin) / (resolution - 1)
extent = [imin - di / 2, imax + di / 2, jmin - dj / 2, jmax + dj / 2]

# Share one colour scale between background and points, make the background
# translucent and outline the points so training samples stay visible where
# their colour matches the predicted region.
vmin, vmax = np.min(y), np.max(y)
plt.gca().set_aspect('equal')
plt.title(f'{K_NEIGHBORS}-NN decision regions')
plt.imshow(pred_img, origin='lower', cmap='cool', extent=extent,
           vmin=vmin, vmax=vmax, alpha=0.4)
plt.scatter(I, J, c=y, cmap='cool', vmin=vmin, vmax=vmax,
            edgecolors='k', linewidths=0.5);


Out[15]:
<matplotlib.image.AxesImage at 0x7f3432e51ed0>