In [1]:
import json
import numpy as np
import numpy.random as rd
from ipywidgets import widgets
from IPython.display import HTML, Javascript, display
In [2]:
# %load covertree.py
# File: covertree.py
# Date of creation: 05/04/07
# Copyright (c) 2007, Thomas Kollar <tkollar@csail.mit.edu>
# Copyright (c) 2011, Nil Geisweiller <ngeiswei@gmail.com>
# All rights reserved.
#
# This is a class for the cover tree nearest neighbor algorithm. For
# more information please refer to the technical report entitled "Fast
# Nearest Neighbors" by Thomas Kollar or to "Cover Trees for Nearest
# Neighbor" by John Langford, Sham Kakade and Alina Beygelzimer
#
# If you use this code in your research, kindly refer to the technical
# report.
import numpy as np
import operator
from random import choice
from heapq import nsmallest, heappush, heappop
from itertools import product
from collections import Counter
'''
try:
from collections import Counter
except ImportError: # Counter is not available in Python before v2.7
from recipe_576611_1 import Counter
try:
from joblib import Parallel, delayed
except ImportError:
pass
import cStringIO
'''
def unique(container):
    """Return True iff exactly one element of `container` is tallied as True.

    The elements are tallied with a Counter, so `container` may be any
    iterable of hashable values (including a generator).
    """
    tally = Counter(container)
    # Counter returns 0 for a missing key, so no explicit default is needed.
    return tally[True] == 1
class Node:
    """A single point stored in the cover tree.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, data=None, idx=None):
        # payload of the node (an array of values)
        self.data = data
        # maps a level to the list of children attached at that level
        self.children = {}
        # node this one is attached to (None while unattached)
        self.parent = None
        # index handed out by the owning tree
        self.idx = idx

    def addChild(self, child, i):
        """Attach `child` to this node at level `i` (no duplicates) and set its parent."""
        siblings = self.children.setdefault(i, [])
        if child not in siblings:
            siblings.append(child)
        child.parent = self

    def getChildren(self, level):
        """Return a new list holding this node followed by its children at `level`."""
        return [self] + self.children.get(level, [])

    def getOnlyChildren(self, level):
        """Like getChildren, but without the node itself.

        Returns the stored list when `level` has children, else a fresh
        empty list.
        """
        return self.children.get(level, [])

    def removeConnections(self, level):
        """Detach this node from its parent, where it is stored under `level + 1`."""
        if self.parent is None:
            return
        self.parent.children[level + 1].remove(self)
        self.parent = None

    def __str__(self):
        return str(self.data)

    def __repr__(self):
        return str(self.data)
class CoverTree:
    """Cover tree for nearest-neighbour queries under a caller-supplied distance.

    A child inserted at level ``i`` is stored under key ``i`` of its
    parent's ``children`` dict and lies within ``base**i`` of that parent.
    Keys therefore range from ``minlevel + 1`` up to ``maxlevel``.
    """

    def __init__(self, distance, data=None, root = None, maxlevel = 10, base = 2,
                 jobs = 1, min_len_parallel = 100):
        """Create a tree and insert every point of `data`.

        Input:
         - distance :: callable (p, q) -> float
         - data :: optional iterable of points to insert
         - root :: optional root, stored as-is and later used as a Node
           (its ``data`` and ``children`` are read)
         - maxlevel :: top level; ``base**maxlevel`` bounds the distance
           from the root at which points can be inserted
         - base :: base of the per-level radius ``base**i``
         - jobs, min_len_parallel :: parallelism settings; stored, but not
           used by any method of this class
        """
        self.distance = distance
        self.root = root
        self.maxlevel = maxlevel
        self.minlevel = maxlevel  # the minlevel will adjust automatically
        self.idx = 0
        self.base = base
        self.jobs = jobs
        self.min_len_parallel = min_len_parallel
        # for printDotty
        self.__printHash__ = set()
        if data is None:
            data = []
        for point in data:
            self.insert(point)

    @property
    def size(self):
        "Number of elements in the tree"
        return self.idx

    def insert(self, p):
        """Insert point `p`; the first inserted point becomes the root."""
        if self.root is None:
            self.root = self._newNode(p)
        else:
            self._insert_iter(p)

    def _newNode(self, *args, **kws):
        """Build a Node carrying the next free index."""
        kws['idx'] = self.idx
        self.idx += 1
        return Node(*args, **kws)

    def __iter__(self):
        """
        Breadth-first traversal of the nodes in the tree
        Output:
         - iterable of (idx, point)
        """
        if self.root is None:
            return  # empty tree: nothing to yield
        # Start one level above maxlevel: each step expands the children
        # stored under `lvl - 1`, and children may be stored under key
        # `maxlevel` itself.
        queue = [(self.maxlevel + 1, self.root)]
        observed = set()
        while queue:
            lvl, node = queue.pop(0)
            if node not in observed:
                yield node.idx, node.data
                observed.add(node)
            next_lvl = lvl - 1
            if next_lvl < self.minlevel:
                continue
            # getChildren includes `node` itself, so it is carried down a level
            for child in node.getChildren(next_lvl):
                queue.append((next_lvl, child))

    def extend(self, iterable):
        """Insert every point of `iterable`.

        Iterating a CoverTree yields (idx, point) pairs, so only the point
        is taken in that case.
        """
        if isinstance(iterable, CoverTree):
            getter = operator.itemgetter(1)
        else:
            getter = lambda x: x
        for p in map(getter, iterable):
            self.insert(p)

    def _insert_iter(self, p):
        """Insert point `p` into a non-empty tree.

        Raises:
         - ValueError if no parent candidate is found, i.e. `p` is farther
           than ``base**maxlevel`` from every node reachable at maxlevel.
        """
        Qi_p_ds = [(self.root, self.distance(p, self.root.data))]
        i = self.maxlevel
        parent = None
        pi = None
        while True:
            # get the children of the current level
            # and the distance of the all children
            Q_p_ds = self._getChildrenDist_(p, Qi_p_ds, i)
            d_p_Q = self._min_ds_(Q_p_ds)
            if d_p_Q == 0.0:  # already there, no need to insert
                return
            elif d_p_Q > self.base**i:  # the found parent should be right
                break
            else:  # d_p_Q <= self.base**i, keep iterating
                # find parent
                if self._min_ds_(Qi_p_ds) <= self.base**i:
                    parent = choice([q for q, d in Qi_p_ds if d <= self.base**i])
                    pi = i
                # construct Q_i-1
                Qi_p_ds = [(q, d) for q, d in Q_p_ds if d <= self.base**i]
                i -= 1
        if parent is None:
            # The loop broke on its first pass, so no node covers p.
            raise ValueError(
                'Cannot insert {}: it is farther than base**maxlevel = {} '
                'from the root; use a larger maxlevel'.format(
                    p, self.base**self.maxlevel))
        # insert p
        parent.addChild(self._newNode(p), pi)
        # update self.minlevel
        self.minlevel = min(self.minlevel, pi-1)

    def neighbors(self, point, radius):
        """
        Overview: get the neighbors of `point` within distance `radius`
        Input:
         - point :: a point
         - radius :: float - the maximum (inclusive) distance
        Output:
         - [(i, n, d)] :: list of triples (`index`, `point`, `float`): each
           matching point with its index and its distance to `point`
        """
        if self.root is None:
            return []
        result = set()
        # Start one level above maxlevel so that children stored under key
        # `maxlevel` are expanded too (each step expands key `level - 1`).
        queue = [(self.maxlevel + 1, self.root,
                  self.distance(point, self.root.data))]
        while queue:
            level, node, dist = queue.pop(0)
            # skip nodes farther away than radius + base**level
            if dist > radius + self.base**level:
                continue
            if dist <= radius:
                result.add((node, dist))
            next_level = level - 1
            if next_level < self.minlevel:
                continue
            for child in node.getChildren(next_level):
                if not child == node:
                    d = self.distance(point, child.data)
                else:
                    d = dist  # the node itself: reuse the known distance
                queue.append((next_level, child, d))
        return [(node.idx, node.data, dist) for node, dist in result]

    def contains(self, point, eps=0.00001):
        """
        Ask if the cover tree contains a given point
        Input:
         - point :: the query point -- the point to search for
         - eps :: double -- epsilon for distance comparison
        Output:
         - found :: bool -- indicates presence of point in Cover Tree
        Raises:
         - ValueError if more than one stored point lies within `eps`
        """
        nn = list(self.neighbors(point, eps))
        if len(nn) == 1:
            return True
        elif len(nn) == 0:
            return False
        else:
            raise ValueError(
                'Found multiple results for {} with eps={}: {}'.format(point, eps, nn))

    def knn(self, p, k):
        """
        Get the `k` nearest neighbors of `p`
        Input:
         - p :: a point
         - k :: positive int
        Output:
         - [(i, p, d)] :: list of at most `k` triples of the index, point,
           and distance of the stored points closest to `p`; empty for an
           empty tree
        """
        if self.root is None:
            return []
        Qi_p_ds = [(self.root, self.distance(p, self.root.data))]
        for i in reversed(range(self.minlevel, self.maxlevel+1)):
            Q_p_ds = self._getChildrenDist_(p, Qi_p_ds, i)
            # distance of the k-th closest candidate seen so far
            _, d_p_Q = self._kmin_p_ds_(k, Q_p_ds)[-1]
            Qi_p_ds = [(q, d) for q, d in Q_p_ds if d <= d_p_Q + self.base**i]
        res = [(q.idx, q.data, d) for q, d in Qi_p_ds]
        return nsmallest(k, res, key=operator.itemgetter(2))

    def _getChildrenDist_(self, p, Qi_p_ds, i):
        """Extend a cover set with its level-`i` children.

        Input: point p, and Qi_p_ds, the (node, distance to p) pairs of the
        cover set Qi.
        Output: Qi_p_ds followed by (child, distance to p) for every child
        stored under key `i` of a node of Qi.
        """
        Q_p_ds = [(q, self.distance(p, q.data))
                  for n, _ in Qi_p_ds
                  for q in n.getOnlyChildren(i)]
        return Qi_p_ds + Q_p_ds

    def _kmin_p_ds_(self, k, Q_p_ds):
        """Return the (up to) k pairs <node, distance> with the smallest distances."""
        return nsmallest(k, Q_p_ds, lambda x: x[1])

    def _min_ds_(self, Q_p_ds):
        """Return the minimum distance found in Q_p_ds."""
        return self._kmin_p_ds_(1, Q_p_ds)[0][1]

    def _result_(self, res, without_distance):
        """Format a result list.

        If without_distance is True return only the data points, otherwise
        return pairs <point.data, distance>.
        """
        if without_distance:
            return [p.data for p, _ in res]
        else:
            return [(p.data, d) for p, d in res]

    def writeDotty(self, outputFile):
        """Write the dot (graphviz) representation of the tree to `outputFile`."""
        outputFile.write("digraph {\n")
        self._writeDotty_rec(outputFile, [self.root], self.maxlevel)
        outputFile.write("}")

    def _writeDotty_rec(self, outputFile, C, i):
        """Helper of writeDotty: emit the edges from level `i` down to minlevel.

        Input: C, the nodes present at level i.
        """
        if(i == self.minlevel):
            return
        children = []
        for p in C:
            childs = p.getChildren(i)
            for q in childs:
                outputFile.write("\"lev:" +str(i) + " "
                                 + str(p.data) + "\"->\"lev:"
                                 + str(i-1) + " "
                                 + str(q.data) + "\"\n")
            children.extend(childs)
        self._writeDotty_rec(outputFile, children, i-1)

    # NOTE: a __str__ built on writeDotty used to exist here; it was disabled
    # because it relied on Python 2's cStringIO.

    def _check_invariants(self):
        """Check that the tree satisfies all three invariants."""
        return self._check_nesting() and \
            self._check_covering_tree() and \
            self._check_seperation()

    def _check_my_invariant(self, my_invariant):
        """Check my_invariant(C_i, C_{i-1}, i) for every level i.

        C_i denotes the set of nodes at level i.
        """
        C = [self.root]
        for i in reversed(range(self.minlevel, self.maxlevel + 1)):
            C_next = [q for p in C for q in p.getChildren(i)]
            if not my_invariant(C, C_next, i):
                print("At level", i, "the invariant", my_invariant, "is false")
                return False
            C = C_next
        return True

    def _nesting(self, C, C_next, _):
        """Nesting invariant: C_i is a subset of C_{i-1}."""
        return set(C) <= set(C_next)

    def _check_nesting(self):
        return self._check_my_invariant(self._nesting)

    def _covering_tree(self, C, C_next, i):
        """Covering-tree invariant.

        For all p in C_{i-1} there exists a q in C_i so that
        d(p, q) <= base^i and exactly one such q is a parent of p.
        """
        return all(unique(self.distance(p.data, q.data) <= self.base**i
                          and p in q.getChildren(i)
                          for q in C)
                   for p in C_next)

    def _check_covering_tree(self):
        return self._check_my_invariant(self._covering_tree)

    def _seperation(self, C, _, i):
        """Separation invariant: for all distinct p, q in C_i, d(p, q) > base^i."""
        return all(self.distance(p.data, q.data) > self.base**i
                   for p, q in product(C, C) if p != q)

    def _check_seperation(self):
        return self._check_my_invariant(self._seperation)
In [3]:
def json_numpy_serializer(o):
    """`default` hook for json.dumps that converts numpy arrays to nested lists.

    Raises TypeError for any other object.
    """
    if not isinstance(o, np.ndarray):
        raise TypeError("{} of type {} is not JSON serializable".format(repr(o), type(o)))
    return o.tolist()
def jsglobal(**params):
    """Publish each keyword argument as a global on the browser's `window`.

    Every value is JSON-encoded (numpy arrays go through
    json_numpy_serializer) and assigned to `window.<name>` by a single
    displayed Javascript snippet.
    """
    assignments = [
        "window.{}={};".format(name, json.dumps(value, default=json_numpy_serializer))
        for name, value in params.items()
    ]
    display(Javascript("\n".join(assignments)))
In [4]:
#from covertree import CoverTree
def l2(p, q):
    """Euclidean distance between `p` and `q`: sqrt((p - q) @ (p - q))."""
    delta = p - q
    return np.sqrt(delta @ delta)
def extract_node_data(n0, x, y, z, levels, links):
    """Walk the subtree below `n0` depth-first and append to the output lists.

    For every descendant, its first three data components go to `x`, `y`
    and `z`, the key it is stored under goes to `levels`, and the pair
    (parent data, child data) is appended to `links` as two consecutive
    entries. `n0` itself is not recorded. The lists are modified in place.
    """
    for level, kids in n0.children.items():
        for kid in kids:
            x.append(kid.data[0])
            y.append(kid.data[1])
            z.append(kid.data[2])
            levels.append(level)
            links.extend((n0.data, kid.data))
            # recurse before moving on to the next sibling
            extract_node_data(kid, x, y, z, levels, links)
# Draw num_pts random 3-D points and scale each row to unit length.
num_pts = 100
x = rd.randn(3*num_pts).reshape((num_pts, 3))
x = np.apply_along_axis(lambda v: v / np.sqrt(np.sum(v**2)), 1, x)

# Build the cover tree one point at a time.
cube_ct = CoverTree(l2)
for i in range(num_pts):
    cube_ct.insert(x[i, :])

# Flatten the tree into parallel lists, root first.
# Note that x is reused here: from now on it holds coordinates, not the points.
x = [cube_ct.root.data[0]]
y = [cube_ct.root.data[1]]
z = [cube_ct.root.data[2]]
r = {'min': cube_ct.minlevel, 'max': cube_ct.maxlevel}
levels = [cube_ct.maxlevel]
links = []
extract_node_data(cube_ct.root, x, y, z, levels, links)

# One row per node: (x, y, z).
pos = np.vstack((x, y, z)).T

# Hand the data to the JavaScript side as window.POS / LINKS / R / LEVELS.
jsglobal(POS=pos)
jsglobal(LINKS=links)
jsglobal(R=r)
jsglobal(LEVELS=levels)
In [5]:
%%javascript
// window.location.reload(true)
// Tell RequireJS where the compiled MathBox bundle lives; it is
// registered under the module id 'mathBox'.
require.config({
    paths: {
        mathBox: '../static/mathbox/build/mathbox-bundle'
    },
    baseUrl: ''
});
// Helper that loads the MathBox bundle, creates a MathBox instance inside the
// given output element, and passes that instance to `func`.
window.with_mathbox = function(element, func) {
    require(['mathBox'], function() {
        var options = {
            plugins: ['core', 'controls', 'cursor', 'mathbox'],
            controls: { klass: THREE.OrbitControls },
            mathbox: {inspect: false},
            element: element[0],
            loop: {start: false},   // the render loop is started by the poller below
        };
        var mathbox = mathBox(options);
        var three = mathbox.three;

        // White background, fixed camera position, keyboard controls off.
        three.renderer.setClearColor(new THREE.Color(0xFFFFFF), 1.0);
        three.camera.position.set(-1, 1, 2);
        three.controls.noKeys = true;
        three.element.style.height = "400px";
        three.element.style.width = "100%";

        // True when any part of `node`'s bounding box is inside the window.
        function isInViewport(node) {
            var box = node.getBoundingClientRect();
            var root = document.documentElement;
            var viewWidth = window.innerWidth || root.clientWidth;
            var viewHeight = window.innerHeight || root.clientHeight;
            return box.top < viewHeight && box.left < viewWidth && box.bottom > 0 && box.right > 0;
        }

        // Every 100 ms: destroy the plot once its element has no offsetParent,
        // otherwise run the render loop only while the canvas is visible.
        var pollTimer = setInterval(function() {
            if (three.element.offsetParent === null) {
                clearInterval(pollTimer);
                three.destroy();
                return;
            }
            var onScreen = isInViewport(three.canvas);
            if (three.Loop.running != onScreen) {
                if (onScreen) {
                    three.Loop.start();
                } else {
                    three.Loop.stop();
                }
            }
        }, 100);

        func(mathbox);
        window.dispatchEvent(new Event('resize'));
    })
}
In [6]:
%%javascript
with_mathbox(element, function(mathbox) {
    // Cartesian view whose rotation about the second axis grows with time t,
    // with a grid on axes 1 and 3.
    var spin = function(t) { return [0, t*0.02, 0]; };
    var view = mathbox.cartesian({}, {rotation: spin})
        .grid({axes: [1, 3]})

    // Emit point i only during the part of each cycle in which its level is
    // "switched on"; a cycle lasts (R.max - R.min + 10) time units.
    // The callback keeps exactly three parameters (emit, i, time).
    function emitPoint(emit, i, time) {
        var cycle = R['max'] - R['min'] + 10;
        var threshold = R['max'] - LEVELS[i];
        if ((time % cycle) > threshold) {
            emit(POS[i][0], POS[i][1], POS[i][2]);
        }
    }

    view.array({
        width: LEVELS.length,
        expr: emitPoint,
        channels: 3
    });

    // Draw the emitted positions as points.
    view.point({color:"#55a", size: 10});
    //view.array({width: LINKS.length/2, items: 2, channels: 3, data: LINKS, live: false}).vector({color: 0x4444ff, width: 1});
})
In [7]: