The K-nearest neighbor algorithm

In this notebook we focus on reproducing the result of Fig. 2.15.


In [5]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import euclidean
from tqdm import tqdm
from time import sleep

%matplotlib inline
LARGE_SIZE = (12,8)

Generating the data

We use the make_moons method from scikit-learn, which allows to generate directly a set of points and the associated labels.


In [6]:
from sklearn.datasets import make_moons

In [7]:
X, y = make_moons(n_samples=100, noise=0.2)

In [8]:
colz = {0: 'dodgerblue', 1:"goldenrod"}

for pt, cl in zip(X, y):
    plt.scatter(pt[0],pt[1], color=colz[cl])

plt.scatter([], [], color='dodgerblue', label='0')
plt.scatter([], [], color='goldenrod', label='1')
plt.legend()


Out[8]:
<matplotlib.legend.Legend at 0x113573080>

Writing the KNN function


In [9]:
def knn(pt, X, y, k):
    """
    Returns the class predicted for pt using the knn algorithm
    """
    
    dists = []
    # Put features and outputs in the same array:
    data = np.stack((X[:,0], X[:,1], y), axis=1)
    # Compute the distance between pt and all the data points:
    dists = [euclidean(pt, [data_pt[0],data_pt[1]]) for data_pt in data]
    # Add the corresponding outputs:
    dists_cl = np.stack((np.array(dists), data[:,2]), axis=1)
    # Order by ascending distance:
    dists_cl_ordered = dict(dists_cl)
    # Simplest case, less fuss than k>1:
    if k==1:
        one_nn_class = dists_cl_ordered[min(dists_cl_ordered.keys())]
        return int(one_nn_class)
    if k>1:
        nearests = []
        # Recursively pop the 1-nn from the distances:
        for i in range(k):
            curr_min = min(dists_cl_ordered.keys())
            nearests.append([curr_min,dists_cl_ordered[curr_min]])
            dists_cl_ordered.pop(curr_min, None)
        # Get the majoritary class among the popped elements:
        nearests_cl = np.array(nearests)[:,1].astype(int)
        k_nn_class = np.argmax(np.bincount(nearests_cl))
        return int(k_nn_class)

Applying the KNN function to the moons dataset


In [25]:
# Create the grid:
resolution = 0.05
x1_min = min(X[:,0]) - 0.5
x1_max = max(X[:,0]) + 0.5
x2_min = min(X[:,1]) - 0.5
x2_max = max(X[:,1]) + 0.5
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),\
                       np.arange(x2_min, x2_max, resolution))
pts_domain = np.stack((np.ravel(xx1), np.ravel(xx2)), axis=1)

In [26]:
pts_domain.shape


Out[26]:
(5246, 2)

In [29]:
k_knn = 3
for ptdx, ptdy in tqdm(zip(pts_domain[:,0], pts_domain[:,1])):
    plt.scatter(ptdx, ptdy, color=colz[knn([ptdx, ptdy], X, y, k_knn)],\
                marker='.', s=3, \
               alpha=0.7)
for pt, cl in zip(X, y):
    plt.scatter(pt[0],pt[1], color=colz[cl])

plt.scatter([], [], color='dodgerblue', label='0')
plt.scatter([], [], color='goldenrod', label='1')
plt.legend()

plt.xlim([x1_min, x1_max])
plt.ylim([x2_min, x2_max])
plt.xlabel(r'$X_1$', fontsize=12)
plt.ylabel(r'$X_2$', fontsize=12)


5246it [00:42, 122.62it/s]
Out[29]:
<matplotlib.text.Text at 0x198be2cf8>