This notebook works through the details of the k-Nearest Neighbors (kNN) example from the great talk "Losing your Loops: Fast Numerical Computing with NumPy" by Jake VanderPlas. :-)
In [1]:
import numpy as np
In [2]:
# 1000 points in 3 dimensions, drawn uniformly from [0, 1).
# A seeded generator makes every downstream output reproducible after
# Restart Kernel -> Run All; change SEED to get a different sample.
SEED = 0
rng = np.random.default_rng(SEED)
x = rng.random((1000, 3))
x.shape
Out[2]:
In [3]:
# Display the generated points: one row per point, one column per coordinate.
x
Out[3]:
In [4]:
# The first point: a length-3 vector of coordinates.
x[0]
Out[4]:
In [5]:
# Broadcasting to find pairwise difference:
# insert a length-1 middle axis so x becomes (n_points, 1, n_dims). For the
# (1000, 3) array built above this is the same as x.reshape((1000, 1, 3)),
# but it does not hard-code the number of points or dimensions.
x[:, np.newaxis, :]
Out[5]:
In [6]:
# First slice of the (n_points, 1, n_dims) view: point 0 as a (1, n_dims) row.
# np.newaxis is used instead of a hard-coded reshape((1000, 1, 3)).
x[:, np.newaxis, :][0]
Out[6]:
In [7]:
# (n_points, 1, n_dims) - (n_points, n_dims) broadcasts to
# (n_points, n_points, n_dims), so pair_diff[a, b] == x[a] - x[b].
# np.newaxis replaces the hard-coded reshape((1000, 1, 3)) so this cell
# keeps working if the number of points or dimensions changes.
pair_diff = x[:, np.newaxis, :] - x
pair_diff.shape
Out[7]:
In [8]:
# Display the pairwise differences; entry [a, b] is the vector x[a] - x[b].
pair_diff
Out[8]:
In [9]:
# First step of the aggregation: square every coordinate difference
# element-wise. The shape is the same as pair_diff.
pair_diff2 = np.square(pair_diff)
pair_diff2.shape
Out[9]:
In [10]:
# Display the element-wise squared differences.
pair_diff2
Out[10]:
In [11]:
# Manual spot check of the element-wise square.
# NOTE(review): the literal is presumably copied from one entry of an earlier
# pair_diff output; x is random, so it will not match a fresh run -- confirm.
(-0.34797723) ** 2
Out[11]:
In [12]:
# Sum the squared differences over the last (coordinate) axis, giving one
# squared distance per pair of points (no square root is taken).
pair_dist = np.sum(np.square(pair_diff), axis=2)
pair_dist.shape
Out[12]:
In [13]:
# Display the pairwise squared-distance matrix; entry [a, b] relates point a
# to point b, so the diagonal holds each point's distance to itself.
pair_dist
Out[13]:
In [14]:
# Set diagonal to infinity to skip self-neighbors.
# The index range is taken from pair_dist itself rather than a hard-coded
# 1000, so it always matches the size of the matrix.
i = np.arange(pair_dist.shape[0])
pair_dist[i, i] = np.inf
In [15]:
# Display pair_dist again; the diagonal entries should now be inf.
pair_dist
Out[15]:
In [16]:
# Row-wise argmin: for each point, the column index of the smallest entry in
# its row of pair_dist (its nearest neighbor once the diagonal is inf).
# Only the first 10 indices are printed.
neighbors = pair_dist.argmin(axis=1)
print(neighbors[:10])
In [17]:
# k-Nearest Neighbors in summary (this finds the single nearest neighbor).
# Shapes are derived from x instead of hard-coding 1000 points / 3 dimensions.

# Broadcasting to find pairwise difference:
# (n_points, 1, n_dims) - (n_points, n_dims) -> (n_points, n_points, n_dims)
pair_diff = x[:, np.newaxis, :] - x
# Aggregate over the coordinate axis to find pairwise (squared) distance
pair_dist = (pair_diff ** 2).sum(axis=2)
# Set diagonal to infinity to skip self-neighbors
i = np.arange(pair_dist.shape[0])
pair_dist[i, i] = np.inf
# Show the distance matrix (the diagonal is now inf)
print(pair_dist)
# Find the nearest neighbor of each point: index of the row-wise minimum
neighbors = np.argmin(pair_dist, axis=1)
print(neighbors[:10])
In [ ]: