KD-trees

Question 1

Screenshot taken from Coursera

Question 2

Screenshot taken from Coursera

Answer


In [74]:
import numpy as np

# Coordinates of the six data points, one array per feature
x1 = np.array([-1.58, 0.91, -0.73, -4.22, 4.19, -0.33])
x2 = np.array([-2.01, 3.98, 4.00, 1.16, -2.02, 2.15])

# One row per data point: column 0 holds x1, column 1 holds x2
x = np.column_stack((x1, x2))
x


Out[74]:
array([[-1.58, -2.01],
       [ 0.91,  3.98],
       [-0.73,  4.  ],
       [-4.22,  1.16],
       [ 4.19, -2.02],
       [-0.33,  2.15]])

In [89]:
# Midrange of x1: halfway between its largest and smallest value
x1_high, x1_low = x1.max(), x1.min()
x1_midrange = (x1_high + x1_low) / 2
x1_midrange


Out[89]:
-0.01499999999999968

In [90]:
def get_mid_range(data, column=0):
    """Return the midrange of one column of ``data``.

    The midrange is the average of the column's largest and smallest value.

    Parameters
    ----------
    data : array indexed as ``data[:, column]``
        The data points, one row per point.
    column : int, optional
        Column to evaluate:
        - x1: column=0
        - x2: column=1

    Returns
    -------
    ``(max + min) / 2`` of the selected column.
    """
    values = data[:, column]
    # Divide by 2.0 rather than 2: this notebook runs on Python 2, where
    # int / int floor-divides and would truncate the midrange of integer data.
    return (values.max() + values.min()) / 2.0

def split_by(x, value, column=0):
    """Partition the rows of ``x`` around a threshold on one column.

    Parameters
    ----------
    x : array indexed as ``x[:, column]``
        The data points, one row per point.
    value : scalar
        Threshold the selected column is compared against.
    column : int, optional
        Column to compare:
        - x1: column=0
        - x2: column=1

    Returns
    -------
    tuple
        ``(rows with column <= value, rows with column > value)``.
    """
    keys = x[:, column]
    at_or_below = keys <= value
    above = keys > value
    return x[at_or_below], x[above]

In [91]:
# Root split: cut the full data set at the midrange of x1 (column 0)
x1_midrange = get_mid_range(x, column=0)
x1_split1, x1_split2 = split_by(x, x1_midrange, column=0)

# Left half: points with x1 <= x1_midrange
x1_split1


Out[91]:
array([[-1.58, -2.01],
       [-0.73,  4.  ],
       [-4.22,  1.16],
       [-0.33,  2.15]])

In [92]:
# Right half of the x1 split: points with x1 > x1_midrange
x1_split2


Out[92]:
array([[ 0.91,  3.98],
       [ 4.19, -2.02]])

Question 3

Screenshot taken from Coursera

Answer


In [93]:
# Midrange of x2 (column 1) within the left half of the x1 split
x1_split1_x2_midrange = get_mid_range(x1_split1, column=1)

# print() with parentheses runs on both Python 2 and Python 3
print(x1_split1_x2_midrange)


0.995

In [94]:
# Midrange of x2 (column 1) within the right half of the x1 split
x1_split2_x2_midrange = get_mid_range(x1_split2, column=1)
# print() with parentheses runs on both Python 2 and Python 3
print(x1_split2_x2_midrange)


0.98

Question 4

Screenshot taken from Coursera

Answer


In [95]:
# Split the left x1 half on x2 (column 1) at its midrange.
# The threshold is x1_split1_x2_midrange, the name the midrange was stored
# under; the previously used name was never defined and fails on a fresh kernel.
x1_split1_x2_split1, x1_split1_x2_split2 = split_by(x1_split1, x1_split1_x2_midrange, column=1)
x1_split1_x2_split1


Out[95]:
array([[-1.58, -2.01]])

In [97]:
# Upper part of the x2 split of the left x1 half.
# This node still has 3 data points, so it needs to be split further.

x1_split1_x2_split2


Out[97]:
array([[-0.73,  4.  ],
       [-4.22,  1.16],
       [-0.33,  2.15]])

In [98]:
# Midrange of x2 (column 1) for the node that still holds three points
x1_split1_x2_split2_midrange = get_mid_range(
    x1_split1_x2_split2,
    column=1,
)
x1_split1_x2_split2_midrange


Out[98]:
2.5800000000000001

In [99]:
# Split the three-point node on x2 (column 1) at its midrange
(x1_split1_x2_split2_x2_split1,
 x1_split1_x2_split2_x2_split2) = split_by(
    x1_split1_x2_split2,
    x1_split1_x2_split2_midrange,
    column=1,
)

# Lower part (x2 <= midrange); continue to split
x1_split1_x2_split2_x2_split1
# Continue to split


Out[99]:
array([[-4.22,  1.16],
       [-0.33,  2.15]])

In [100]:
# Upper part of the same split: points with x2 > x1_split1_x2_split2_midrange
x1_split1_x2_split2_x2_split2


Out[100]:
array([[-0.73,  4.  ]])

In [102]:
# Midrange of x2 (column 1) for the remaining node that is split next
x1_split1_x2_split2_x2_split1_midrange = get_mid_range(
    x1_split1_x2_split2_x2_split1,
    column=1,
)
x1_split1_x2_split2_x2_split1_midrange


Out[102]:
1.6549999999999998
  • After splitting x1_split1_x2_split2_x2_split1 at its x2 midrange (1.655), we get two more leaves. Data point 4, [-4.22, 1.16], is in the leaf that contains the query point (-3, 1.5).

Question 5

Screenshot taken from Coursera

Answer


In [86]:
# Split the right x1 half on x2 (column 1) at its midrange
x1_split2_x2_split1, x1_split2_x2_split2 = split_by(
    x1_split2,
    x1_split2_x2_midrange,
    column=1,
)

# Lower part: points with x2 <= x1_split2_x2_midrange
x1_split2_x2_split1


Out[86]:
array([[ 4.19, -2.02]])

In [87]:
# Upper part of the right x1 half: points with x2 > x1_split2_x2_midrange
x1_split2_x2_split2


Out[87]:
array([[ 0.91,  3.98]])