NumPy Broadcasting with kNN

  • Bowen Li
  • 2016/09/27
  • Revised: 2017/04/01

This notebook is to learn the details of k-Nearest Neighbors (kNN) in the great talk: "Losing your Loops Fast Numerical Computing with NumPy" by Jake VanderPlas. :-)


In [1]:
import numpy as np

In [2]:
# 1000 points in 3 dimensions
x = np.random.random((1000, 3))
x.shape


Out[2]:
(1000, 3)

In [3]:
x


Out[3]:
array([[ 0.36405342,  0.864484  ,  0.23917426],
       [ 0.36582774,  0.70996984,  0.64633359],
       [ 0.4640329 ,  0.03019656,  0.77808203],
       ..., 
       [ 0.36222717,  0.02370646,  0.86896656],
       [ 0.07332455,  0.67144581,  0.01419984],
       [ 0.00279737,  0.73408481,  0.95834866]])

In [4]:
x[0]


Out[4]:
array([ 0.36405342,  0.864484  ,  0.23917426])

In [5]:
# Broadcasting to find pairwise difference
x.reshape((1000, 1, 3))


Out[5]:
array([[[ 0.36405342,  0.864484  ,  0.23917426]],

       [[ 0.36582774,  0.70996984,  0.64633359]],

       [[ 0.4640329 ,  0.03019656,  0.77808203]],

       ..., 
       [[ 0.36222717,  0.02370646,  0.86896656]],

       [[ 0.07332455,  0.67144581,  0.01419984]],

       [[ 0.00279737,  0.73408481,  0.95834866]]])

In [6]:
x.reshape((1000, 1, 3))[0]


Out[6]:
array([[ 0.36405342,  0.864484  ,  0.23917426]])

In [7]:
pair_diff = x.reshape((1000, 1, 3)) - x
pair_diff.shape


Out[7]:
(1000, 1000, 3)

In [8]:
pair_diff


Out[8]:
array([[[ 0.        ,  0.        ,  0.        ],
        [-0.00177432,  0.15451415, -0.40715933],
        [-0.09997948,  0.83428744, -0.53890777],
        ..., 
        [ 0.00182625,  0.84077754, -0.6297923 ],
        [ 0.29072887,  0.19303819,  0.22497442],
        [ 0.36125604,  0.13039919, -0.7191744 ]],

       [[ 0.00177432, -0.15451415,  0.40715933],
        [ 0.        ,  0.        ,  0.        ],
        [-0.09820517,  0.67977328, -0.13174844],
        ..., 
        [ 0.00360057,  0.68626338, -0.22263297],
        [ 0.29250319,  0.03852404,  0.63213375],
        [ 0.36303036, -0.02411497, -0.31201507]],

       [[ 0.09997948, -0.83428744,  0.53890777],
        [ 0.09820517, -0.67977328,  0.13174844],
        [ 0.        ,  0.        ,  0.        ],
        ..., 
        [ 0.10180573,  0.0064901 , -0.09088453],
        [ 0.39070835, -0.64124925,  0.76388219],
        [ 0.46123553, -0.70388825, -0.18026663]],

       ..., 
       [[-0.00182625, -0.84077754,  0.6297923 ],
        [-0.00360057, -0.68626338,  0.22263297],
        [-0.10180573, -0.0064901 ,  0.09088453],
        ..., 
        [ 0.        ,  0.        ,  0.        ],
        [ 0.28890262, -0.64773934,  0.85476672],
        [ 0.35942979, -0.71037835, -0.0893821 ]],

       [[-0.29072887, -0.19303819, -0.22497442],
        [-0.29250319, -0.03852404, -0.63213375],
        [-0.39070835,  0.64124925, -0.76388219],
        ..., 
        [-0.28890262,  0.64773934, -0.85476672],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.07052717, -0.062639  , -0.94414882]],

       [[-0.36125604, -0.13039919,  0.7191744 ],
        [-0.36303036,  0.02411497,  0.31201507],
        [-0.46123553,  0.70388825,  0.18026663],
        ..., 
        [-0.35942979,  0.71037835,  0.0893821 ],
        [-0.07052717,  0.062639  ,  0.94414882],
        [ 0.        ,  0.        ,  0.        ]]])

In [9]:
# Aggregate to find pairwise distance
pair_diff2 = pair_diff ** 2
pair_diff2.shape


Out[9]:
(1000, 1000, 3)

In [10]:
pair_diff2


Out[10]:
array([[[  0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
        [  3.14820488e-06,   2.38746240e-02,   1.65778716e-01],
        [  9.99589709e-03,   6.96035530e-01,   2.90421581e-01],
        ..., 
        [  3.33519043e-06,   7.06906868e-01,   3.96638340e-01],
        [  8.45232756e-02,   3.72637440e-02,   5.06134910e-02],
        [  1.30505930e-01,   1.70039483e-02,   5.17211814e-01]],

       [[  3.14820488e-06,   2.38746240e-02,   1.65778716e-01],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
        [  9.64425447e-03,   4.62091718e-01,   1.73576517e-02],
        ..., 
        [  1.29640937e-05,   4.70957431e-01,   4.95654409e-02],
        [  8.55581148e-02,   1.48410153e-03,   3.99593076e-01],
        [  1.31791044e-01,   5.81531611e-04,   9.73534052e-02]],

       [[  9.99589709e-03,   6.96035530e-01,   2.90421581e-01],
        [  9.64425447e-03,   4.62091718e-01,   1.73576517e-02],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
        ..., 
        [  1.03644074e-02,   4.21213885e-05,   8.25999827e-03],
        [  1.52653017e-01,   4.11200595e-01,   5.83515999e-01],
        [  2.12738212e-01,   4.95458669e-01,   3.24960583e-02]],

       ..., 
       [[  3.33519043e-06,   7.06906868e-01,   3.96638340e-01],
        [  1.29640937e-05,   4.70957431e-01,   4.95654409e-02],
        [  1.03644074e-02,   4.21213885e-05,   8.25999827e-03],
        ..., 
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
        [  8.34647234e-02,   4.19566259e-01,   7.30626149e-01],
        [  1.29189777e-01,   5.04637400e-01,   7.98915955e-03]],

       [[  8.45232756e-02,   3.72637440e-02,   5.06134910e-02],
        [  8.55581148e-02,   1.48410153e-03,   3.99593076e-01],
        [  1.52653017e-01,   4.11200595e-01,   5.83515999e-01],
        ..., 
        [  8.34647234e-02,   4.19566259e-01,   7.30626149e-01],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
        [  4.97408241e-03,   3.92364494e-03,   8.91416996e-01]],

       [[  1.30505930e-01,   1.70039483e-02,   5.17211814e-01],
        [  1.31791044e-01,   5.81531611e-04,   9.73534052e-02],
        [  2.12738212e-01,   4.95458669e-01,   3.24960583e-02],
        ..., 
        [  1.29189777e-01,   5.04637400e-01,   7.98915955e-03],
        [  4.97408241e-03,   3.92364494e-03,   8.91416996e-01],
        [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00]]])

In [11]:
(-0.34797723) ** 2


Out[11]:
0.12108815259847291

In [12]:
pair_dist = (pair_diff ** 2).sum(2)
pair_dist.shape


Out[12]:
(1000, 1000)

In [13]:
pair_dist


Out[13]:
array([[ 0.        ,  0.18965649,  0.99645301, ...,  1.10354854,
         0.17240051,  0.66472169],
       [ 0.18965649,  0.        ,  0.48909362, ...,  0.52053584,
         0.48663529,  0.22972598],
       [ 0.99645301,  0.48909362,  0.        , ...,  0.01866653,
         1.14736961,  0.74069294],
       ..., 
       [ 1.10354854,  0.52053584,  0.01866653, ...,  0.        ,
         1.23365713,  0.64181634],
       [ 0.17240051,  0.48663529,  1.14736961, ...,  1.23365713,
         0.        ,  0.90031472],
       [ 0.66472169,  0.22972598,  0.74069294, ...,  0.64181634,
         0.90031472,  0.        ]])

In [14]:
# Set diagonal to infinity to skip self-neighbors
i = np.arange(1000)
pair_dist[i ,i] = np.inf

In [15]:
pair_dist


Out[15]:
array([[        inf,  0.18965649,  0.99645301, ...,  1.10354854,
         0.17240051,  0.66472169],
       [ 0.18965649,         inf,  0.48909362, ...,  0.52053584,
         0.48663529,  0.22972598],
       [ 0.99645301,  0.48909362,         inf, ...,  0.01866653,
         1.14736961,  0.74069294],
       ..., 
       [ 1.10354854,  0.52053584,  0.01866653, ...,         inf,
         1.23365713,  0.64181634],
       [ 0.17240051,  0.48663529,  1.14736961, ...,  1.23365713,
                inf,  0.90031472],
       [ 0.66472169,  0.22972598,  0.74069294, ...,  0.64181634,
         0.90031472,         inf]])

In [16]:
# Print the indices of the nearest neighbors (for the first 10 points)
neighbors = np.argmin(pair_dist, 1)
print(neighbors[:10])


[298 917 983  47 687 929   0 139 651  76]

In [17]:
# k-Nearest Neighbors in summary

# Broadcasting to find pairwise difference
pair_diff = x.reshape((1000, 1, 3)) - x

# Aggregate to find pairwise distance
pair_dist = (pair_diff ** 2).sum(2)

# Set diagonal to infinity to skip self-neighbors
i = np.arange(1000)
pair_dist[i ,i] = np.inf

# Obtain the indices of the nearest neighbors
print(pair_dist)

# Find the nearest neighbors for each points
neighbors = np.argmin(pair_dist, 1)
print(neighbors[:10])


[[        inf  0.18965649  0.99645301 ...,  1.10354854  0.17240051
   0.66472169]
 [ 0.18965649         inf  0.48909362 ...,  0.52053584  0.48663529
   0.22972598]
 [ 0.99645301  0.48909362         inf ...,  0.01866653  1.14736961
   0.74069294]
 ..., 
 [ 1.10354854  0.52053584  0.01866653 ...,         inf  1.23365713
   0.64181634]
 [ 0.17240051  0.48663529  1.14736961 ...,  1.23365713         inf
   0.90031472]
 [ 0.66472169  0.22972598  0.74069294 ...,  0.64181634  0.90031472
          inf]]
[298 917 983  47 687 929   0 139 651  76]

In [ ]: