We have used the Annoy (Approximate Nearest Neighbors Oh Yeah) library for ANN indexing. Annoy constructs an ensemble of random projection index trees by recursively splitting the data space into subspaces using random split hyperplanes.
Below, we illustrate in a two-dimensional space how the ANN index is constructed and how it can be used to speed up nearest neighbor queries.
This code is based on erikbern/ann-presentation.
In [1]:
%matplotlib inline
import descartes
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import nxpd
import seaborn as sns
import shapely.geometry as sg
import shapely.ops as so
import sklearn.datasets
In [2]:
# Use seaborn's 'white' style for all figures in this notebook.
sns.set_style('white')
# A tree node holding more than this many points is split into two children.
MAX_LEAF_SIZE = 25
In [3]:
class TreeNode:
    """Node of a random projection tree over a set of 2D points.

    Constructing a node recursively builds the whole subtree below it: as
    long as a node holds more than ``MAX_LEAF_SIZE`` points (and the depth
    budget is not exhausted) it is split by the perpendicular bisector of
    two randomly chosen points.

    Attributes
    ----------
    parent : TreeNode or None
        The parent node, ``None`` for the root.
    left, right : TreeNode or None
        The child nodes, both ``None`` for a leaf.
    level : int
        Depth of the node in the tree; the root has level 0.
    points : numpy.ndarray
        The points that fall into this node's subspace.
    surface : shapely.geometry.Polygon
        The region of the plane covered by this node.
    """

    def __init__(self, points, parent=None, surface=None, level=None):
        """Create the node and recursively split it into child nodes.

        Parameters
        ----------
        points : numpy.ndarray
            Array of 2D points (one point per row) assigned to this node.
        parent : TreeNode, optional
            The parent node; omit for the root.
        surface : shapely.geometry.Polygon, optional
            Region covered by this node. Defaults to the square spanning
            -20 to 20 along both axes.
        level : int, optional
            Maximum number of further split levels below this node. ``0``
            forces a leaf, ``None`` means no depth limit.
        """
        self.parent = parent
        self.left = None
        self.right = None
        self.level = 0 if parent is None else parent.level + 1
        self.points = points
        if surface is None:
            surface = sg.Polygon([(20, 20), (20, -20), (-20, -20), (-20, 20)])
        self.surface = surface

        # Leaf: few enough points, or the depth budget is used up.
        if len(points) <= MAX_LEAF_SIZE or level == 0:
            return

        # The split hyperplane is the perpendicular bisector of two
        # distinct, randomly drawn points.
        drawn = np.random.choice(points.shape[0], 2, False)
        first, second = points[drawn, :]
        normal = second - first
        midpoint = np.mean([first, second], axis=0)
        offset = np.dot(normal, midpoint)
        direction = np.asarray([normal[1], -normal[0]])

        # Points on the side of `second` (or on the hyperplane) go left,
        # all remaining points go right.
        goes_left = np.dot(points, normal) - offset >= 0

        # Approximate each half-plane by a huge triangle whose base lies on
        # the hyperplane, then clip it to this node's own region.
        far = 1e6
        base_start = midpoint + direction * far
        base_end = midpoint - direction * far
        left_region = sg.Polygon(
            np.array([base_start, midpoint + normal * far, base_end])
        ).intersection(self.surface)
        right_region = sg.Polygon(
            np.array([base_start, midpoint - normal * far, base_end])
        ).intersection(self.surface)

        # Build the left subtree before the right one so the sequence of
        # random draws stays the same for a given seed.
        remaining = None if level is None else level - 1
        self.left = TreeNode(points[goes_left], self, left_region, remaining)
        self.right = TreeNode(points[~goes_left], self, right_region,
                              remaining)
In [4]:
def dfs_tree(root):
    """Yield all nodes of the tree rooted at `root` in depth-first pre-order.

    A node is yielded before its children, and the left subtree is visited
    before the right subtree. Nodes are expected to expose ``left`` and
    ``right`` attributes that are either a child node or ``None``.
    """
    pending = [root]
    while len(pending) > 0:
        current = pending.pop()
        yield current
        # Push right first so that left ends up on top of the stack and is
        # visited next.
        for child in (current.right, current.left):
            if child is not None:
                pending.append(child)
In [5]:
def plot_subspaces(node, query=None, colormap='gist_rainbow',
                   level=None, filename=None):
    """Plot the points and the subspaces of the tree rooted at `node`.

    Parameters
    ----------
    node : TreeNode
        Root of the (sub)tree to plot.
    query : sequence of two floats, optional
        Query point. If given it is drawn as a red cross and every subspace
        that does not contain it is filled white.
    colormap : str
        Name of the seaborn/matplotlib palette the subspace colors are
        taken from.
    level : int, optional
        Only draw the subspaces of nodes whose level is at most this value.
        ``None`` draws all nodes.
    filename : str, optional
        If given, the figure is also saved to this file.
    """
    fig, ax = plt.subplots(figsize=(7, 7))

    # plot the points
    ax.scatter(node.points[:, 0], node.points[:, 1],
               c='black', marker='.', zorder=5)

    # plot the query point
    if query is not None:
        # linewidth must be numeric; matplotlib documents a float here
        ax.scatter(query[0], query[1], s=300, c='red', marker='X',
                   edgecolor='black', linewidth=3, zorder=10)
        query_point = sg.Point(*query)
    else:
        query_point = None

    # plot the subspaces
    # One color per node of a full binary tree of the same depth, so every
    # node can be given its own color.
    max_depth = max([n.level for n in dfs_tree(node)]) + 1
    max_nodes = sum([2 ** i for i in range(max_depth)])
    colors = sns.color_palette(colormap, max_nodes)

    def process_node(current_node, current_colors):
        # A node takes the middle color of its range; the lower half of the
        # range is handed to the left subtree, the upper half to the right.
        if query_point is not None and\
                not current_node.surface.contains(query_point):
            color = 'white'
        else:
            color = current_colors[len(current_colors) // 2]
        if level is None or current_node.level <= level:
            # Children are drawn after (and thus on top of) their parent.
            ax.add_patch(descartes.PolygonPatch(
                current_node.surface, facecolor=color,
                edgecolor='darkgray', zorder=0))
        sep = len(current_colors) // 2
        if current_node.left is not None:
            process_node(current_node.left, current_colors[:sep])
        if current_node.right is not None:
            process_node(current_node.right, current_colors[sep + 1:])

    process_node(node, colors)

    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.axis('off')

    if filename is not None:
        plt.savefig(filename, dpi=300, bbox_inches='tight', pad_inches=0)

    plt.show()
    plt.close()
In [6]:
def plot_ensemble(nodes, query, filename=None):
    """Plot the combined subspace of a query point over several trees.

    For every tree the leaf subspace containing the query point is looked
    up; the union of these leaf subspaces is drawn as a hatched region on
    top of the points.

    Parameters
    ----------
    nodes : sequence of TreeNode
        Roots of the trees in the ensemble. The points of the first root
        are the ones that get plotted.
    query : sequence of two floats
        Query point, drawn as a red cross.
    filename : str, optional
        If given, the figure is also saved to this file.
    """
    fig, ax = plt.subplots(figsize=(7, 7))

    # plot the points
    ax.scatter(nodes[0].points[:, 0], nodes[0].points[:, 1],
               c='black', marker='.', zorder=5)

    # plot the query point
    # linewidth must be numeric; matplotlib documents a float here
    ax.scatter(query[0], query[1], s=300, c='red', marker='X',
               edgecolor='black', linewidth=3, zorder=10)
    query_point = sg.Point(*query)

    # get the query's subspace from all trees
    subspaces = []
    for node in nodes:
        for current_node in dfs_tree(node):
            is_leaf = current_node.left is None and current_node.right is None
            if is_leaf and current_node.surface.contains(query_point):
                subspaces.append(current_node.surface)

    # Nothing to hatch if no leaf contains the query (e.g. the query lies
    # outside the data space); an empty union cannot be drawn as a patch.
    if subspaces:
        # unary_union is the replacement for the deprecated cascaded_union
        ax.add_patch(descartes.PolygonPatch(
            so.unary_union(subspaces), facecolor='none', edgecolor='black',
            hatch='x', linewidth=2.0, zorder=0))

    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.axis('off')

    if filename is not None:
        plt.savefig(filename, dpi=300, bbox_inches='tight', pad_inches=0)

    plt.show()
    plt.close()
In [7]:
def plot_tree(node, query=None, colormap='gist_rainbow',
              level=None, filename=None):
    """Build a graph representation of the tree rooted at `node`.

    Internal nodes are drawn as unlabeled squares, leaves as circles
    labeled with the number of points they hold. Node colors are assigned
    in the same way as in `plot_subspaces`.

    Parameters
    ----------
    node : TreeNode
        Root of the (sub)tree to plot.
    query : sequence of two floats, optional
        Query point. If given, nodes whose subspace does not contain it are
        filled white and the nodes that do contain it get a thick outline.
    colormap : str
        Name of the seaborn/matplotlib palette the node colors are taken
        from.
    level : int, optional
        Only include nodes whose level is at most this value; nodes exactly
        at this level are drawn as leaves. ``None`` includes all nodes.
    filename : str, optional
        If given, the graph is drawn to this file using nxpd.

    Returns
    -------
    networkx.DiGraph
        The graph, with the drawing attributes stored on its nodes.
    """
    tree = nx.DiGraph()

    query_point = sg.Point(*query) if query is not None else None

    # One color per node of a full binary tree of the same depth, so every
    # node can be given its own color.
    max_depth = max([n.level for n in dfs_tree(node)]) + 1
    max_nodes = sum([2 ** i for i in range(max_depth)])
    colors = sns.color_palette(colormap, max_nodes)

    def process_node(current_node, current_colors):
        color = current_colors[len(current_colors) // 2]
        if level is None or current_node.level <= level:
            is_leaf = ((current_node.left is None and
                        current_node.right is None) or
                       current_node.level == level)
            attrs = {'style': 'filled', 'fontsize': 24, 'fontname': 'bold',
                     'label': current_node.points.shape[0] if is_leaf else '""',
                     'shape': 'circle' if is_leaf else 'square'}
            if query_point is not None:
                if not current_node.surface.contains(query_point):
                    attrs['fillcolor'] = '#ffffff'  # white
                else:
                    attrs['fillcolor'] = mcolors.rgb2hex(color)
                    attrs['penwidth'] = 10
            else:
                attrs['fillcolor'] = mcolors.rgb2hex(color)
            tree.add_node(current_node, **attrs)
            # Do not link the start node to its own parent: when a subtree
            # is plotted, add_edge would implicitly add the parent as an
            # extra, unstyled node that is not part of the plotted tree.
            if current_node is not node and current_node.parent is not None:
                tree.add_edge(current_node.parent, current_node)
        sep = len(current_colors) // 2
        if current_node.left is not None:
            process_node(current_node.left, current_colors[:sep])
        if current_node.right is not None:
            process_node(current_node.right, current_colors[sep + 1:])

    process_node(node, colors)

    if filename is not None:
        nxpd.draw(tree, filename=filename, show=False,
                  args=['-Gsize=10,10!', '-Gratio=fill', '-Gdpi=100'])

    return tree
In [8]:
def get_points():
    """Generate the example data set.

    Returns
    -------
    numpy.ndarray
        500 two-dimensional points sampled with scikit-learn's
        ``make_blobs`` from 10 clusters, using a center box of (-10, 10).
        The cluster labels are discarded.
    """
    samples, _labels = sklearn.datasets.make_blobs(
        n_samples=500, n_features=2, centers=10, center_box=(-10, 10))
    return samples
In [9]:
# Seed NumPy's global random number generator before sampling the points.
np.random.seed(1)
points = get_points()
# Query point (x, y) used in the query plots below.
query = (-7, -2.8)
The data space is partitioned into two subspaces using a random split hyperplane.
In [10]:
# Re-seed before building the tree, whose splits are drawn with np.random.
np.random.seed(1)
root = TreeNode(points)
# level=1: only the root and its two children are drawn.
plot_subspaces(root, level=1, filename='index-subspaces_level1.pdf')
Each subspace can recursively be partitioned further.
In [11]:
# No level limit: the subspaces of all nodes are drawn.
plot_subspaces(root, filename='index-subspaces.pdf')
This can be represented as a binary index tree.
In [12]:
# Build the tree graph and write it to a PDF file.
tree = plot_tree(root, filename='index-tree.pdf')
# Draw the graph again for display (presumably inline in the notebook).
nxpd.draw(tree, show='ipynb')
Out[12]:
The data subspace to which a query item belongs is retrieved to find its nearest neighbors.
In [13]:
# With a query point, subspaces that do not contain it are filled white.
plot_subspaces(root, query, filename='index-subspaces_query.pdf')
This subspace can be found efficiently using the binary index tree.
In [14]:
# With a query point, nodes that do not contain it are filled white and the
# nodes that do contain it get a thick outline.
tree = plot_tree(root, query, filename='index-tree_query.pdf')
# Draw the graph again for display (presumably inline in the notebook).
nxpd.draw(tree, show='ipynb')
Out[14]:
A composite data subspace can be compiled from multiple index trees to obtain a better approximation.
In [15]:
# The ensemble consists of the tree built above plus five more trees built
# over the same points.
trees = [root] + [TreeNode(points) for _ in range(5)]
plot_ensemble(trees, query, 'index-ensemble.pdf')