Prerequisites:
Python 3.6: we use f-string notation (PEP 498). Alternatively, you can modify the code to use Python 3.5-compatible formatting.
The easiest way to get Python is via Anaconda.
You need to install TinyDB: pip install tinydb
You need to build our Python wrappers: make bind
You need the faiss Python wrappers.
You need to build our FastText fork: cd fastText ; make
You need to get the dataset from the Extreme Classification Repository.
The downloaded dataset files have to be renamed to train.txt and test.txt (or you can adjust the loader code).
Data should be in <root>/data, or you can change the paths in this script.
In [ ]:
cd ..
System-related imports
In [ ]:
import sys
import os
Add these directories to the Python path so that we can import our modules
In [ ]:
sys.path.append(os.path.abspath('./faiss/'))
sys.path.append(os.path.abspath('./python/'))
Import our functions that read the Extreme Classification Repository data format and transform it to the FastText format
In [ ]:
from experiments.data import get_data
from misc.utils import to_ft, load_sift
Read the Extreme Classification Repository format and transform it to the FastText format (stored in ./data/LSHTC-FT/*.txt)
In [ ]:
X, Y, words_mask, labels_mask = get_data('./data/LSHTC', 'train', min_words=3, min_labels=3)
to_ft(X, Y, './data/LSHTC-FT/train.txt')
X, Y, *_ = get_data('./data/LSHTC', 'test', words_mask=words_mask, labels_mask=labels_mask)
to_ft(X, Y, './data/LSHTC-FT/test.txt')
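For a quick sanity check of the conversion, we can peek at the first converted line. Assuming to_ft emits the standard FastText supervised format (labels carrying the __label__ prefix, followed by the document's words), the line should start with one or more labels:
In [ ]:
# Peek at the first converted training example; with the standard
# FastText supervised format, labels come first with a `__label__`
# prefix, followed by the words of the document.
with open('./data/LSHTC-FT/train.txt') as f:
    print(f.readline()[:200])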
Generate the fasttext commands to be run; they will produce our MIPS dataset.
First they train a FastText model; then, instead of calling predict(), we go through each test example and write three things to disk:
the correct answers,
the hidden vectors from the model (our QUERIES),
the output-weight matrix (our BASE vectors).
In [ ]:
def make_cmd(*args, **kwargs):
    args = ' '.join(args)
    opts = ' '.join(f'-{k} {v}' for k, v in kwargs.items())
    cmd = f'./fastText/fasttext {args} {opts}'
    return cmd.split()

train_cmd = make_cmd('supervised',
                     input='./data/LSHTC-FT/train.txt',
                     output='./data/LSHTC-FT/model.ft',
                     minCount=5,
                     minCountLabel=5,
                     lr=0.1,
                     lrUpdateRate=100,
                     dim=256,
                     ws=5,
                     epoch=25,
                     neg=25,
                     loss='ns',
                     thread=8,
                     saveOutput=1)

generate_cmd = make_cmd('to-fvecs',
                        './data/LSHTC-FT/model.ft.bin',
                        './data/LSHTC-FT/test.txt',
                        './data/LSHTC-FT/fvecs')
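To see what make_cmd produces, note that it just joins the positional arguments and the -key value option pairs, then splits on whitespace into a token list suitable for subprocess. For example:
In [ ]:
# Example of the token list make_cmd builds; this prints
# ['./fastText/fasttext', 'supervised', '-input', 'in.txt', '-dim', '256']
print(make_cmd('supervised', input='in.txt', dim=256))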
Run the commands prepared above
In [ ]:
import subprocess
subprocess.call(train_cmd)     # train the FastText model
subprocess.call(generate_cmd)  # dump labels, hidden vectors, and the output matrix
External imports
In [ ]:
import datetime
import json
import logging
import os
import sys
import time
import uuid
import numpy as np
from contextlib import contextmanager
from tinydb import TinyDB, where
Our imports
In [ ]:
sys.path.append(os.path.abspath('./faiss/'))
sys.path.append(os.path.abspath('./python/'))
import faiss
import mips
from experiments.data import get_data
from misc.utils import to_ft, load_sift
Some basic logging
In [ ]:
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
In [ ]:
@contextmanager
def timer():
    """ Simple context manager to conveniently measure the time between enter and exit """
    class Clock:
        elapsed = 0
    t0 = time.time()
    yield Clock
    Clock.elapsed = time.time() - t0
def search(idx, data, k):
    """ Search the top-`k` items in `data` using `idx`.
    This is needed because currently our wrappers return 1-D arrays, so we need to reshape them.
    """
    D, I = idx.search(data, k)
    D, I = D.reshape(-1, k), I.reshape(-1, k)
    return D, I
def compute_p1(G, I):
    """ Compute precision-at-1 for groundtruth `G` and predicted indices `I` """
    p1 = 0.
    for i, item in enumerate(I):
        p1 += float(int(item) in G[i])
    p1 /= len(G)
    return p1
def test_idx(IdxClass, params, xb, xq, G, k=100):
    """ Train and test the given index class with the given params,
    using the provided base, query, and groundtruth vectors.
    The index will predict the top-`k` entries.
    In case of failure, the exception is returned as a string.
    Returns a Report dictionary.
    """
    try:
        idx = IdxClass(**params)
        with timer() as train_t:
            idx.train(xb)
            idx.add(xb)
        with timer() as search_t:
            _, I = search(idx, xq, k)
        p1 = compute_p1(G, I[:, 0])
        report = make_report(IdxClass, params, p1, train_t.elapsed, search_t.elapsed)
    except Exception as e:
        print('FAILED: ' + str(e))
        report = str(e)
    return report
def now():
    """ Helper function to format the current timestamp """
    return datetime.datetime.fromtimestamp(time.time()).strftime("%d-%m-%y %H:%M:%S")
def make_report(IdxClass, params, p1, train_t, search_t):
    """ Create a Report dictionary for the given set of parameters """
    return {
        'ID': uuid.uuid4().hex,
        'algo': IdxClass.__name__,
        'params': params,
        'p1': p1,
        'train_t': train_t,
        'search_t': search_t
    }
def add_result(r):
    """ Add result `r` to the database """
    if isinstance(r, dict):
        algo, params, p1, t = r['algo'], r['params'], r['p1'], r['search_t']
        rep = f'({algo}: params={params}, p1={p1:.2f}, t={t:.2f})'
    else:
        rep = r
    logger.info(f'Adding: {rep}')

    def result_adder(doc):
        doc['results'].append(r)

    DB.update(result_adder, where('ID') == ID)
def test(IdxClass, **params):
    """ Even higher-level wrapper for testing; uses the globally loaded `xb`, `xq`, and `G` """
    return test_idx(IdxClass, params, xb, xq, G, k=100)
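As a quick sanity check of the helpers above, here is a toy run of timer() and compute_p1 on hand-made data (illustrative only, not part of the benchmark):
In [ ]:
# timer() should report roughly the slept time, and compute_p1 counts
# how often the top-1 prediction is among the groundtruth labels.
with timer() as t:
    time.sleep(0.1)
print(t.elapsed)                 # ~0.1

G_toy = [{1, 2}, {3}]            # groundtruth label sets for two queries
I_toy = np.array([1, 4])         # top-1 predictions for those queries
print(compute_p1(G_toy, I_toy))  # 0.5: the first hits, the second misses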
Create the database we'll store the results in
In [ ]:
# !rm ./data/results/ad-hoc-db.json
DB = TinyDB('./data/results/ad-hoc-db.json')
ID = uuid.uuid4().hex
info = dict(
    ID=ID,
    name='ad-hoc-results',
    date=now(),
    results=[]
)
DB.insert(info)
Load the generated dataset
In [ ]:
xq = load_sift('./data/LSHTC-FT/fvecs.hid.fvecs', dtype=np.float32)
xb = load_sift('./data/LSHTC-FT/fvecs.wo.fvecs', dtype=np.float32)
_n, d, c = xq.shape[0], xq.shape[1], xb.shape[0]
We do not run on all queries, but on a random subset of LIMIT queries
In [ ]:
LIMIT = 250_000
inds = np.random.choice(np.arange(_n), LIMIT, replace=False)
xq = xq[inds, :]
xq = np.copy(np.ascontiguousarray(xq), order='C')
xb = np.copy(np.ascontiguousarray(xb), order='C')
n = xq.shape[0]
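faiss expects float32, C-contiguous arrays, which is why we copy above; a small assertion cell (a sanity check, not required by the pipeline) catches silent dtype or layout mismatches early:
In [ ]:
# Fail early if the copies above did not produce float32, C-contiguous arrays.
for name, arr in [('xq', xq), ('xb', xb)]:
    assert arr.dtype == np.float32, f'{name} has dtype {arr.dtype}'
    assert arr.flags['C_CONTIGUOUS'], f'{name} is not C-contiguous'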
Load groundtruth
In [ ]:
G = []
for line in open('./data/LSHTC-FT/fvecs.labels.txt'):
    G.append({int(y) for y in line.split()})
G = [G[idx] for idx in inds]
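Since the base matrix fits in memory, we can brute-force the exact MIPS answer for one query with numpy and check it against the groundtruth. This hand-rolled check computes the same quantity as the FlatIndex baseline below, for a single query:
In [ ]:
# Exact top-1 MIPS for one query: the class whose output-weight row has
# the largest inner product with the hidden vector.
scores = xb @ xq[0]        # inner products with all classes
top1 = int(np.argmax(scores))
print(top1, top1 in G[0])  # is the exact MIPS answer a correct label?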
Logging
In [ ]:
logger.info(f"Loaded dataset of {_n:_}, {d:_}-dimensionsl queries (examples), but limiting to {LIMIT:_} queries")
logger.info(f"The dataset contains {c:_} classes, and more than one class can be positive")
In [ ]:
# Thin wrappers around the faiss / mips indexes; anything not defined
# here is delegated to the underlying index via __getattr__.
class IVFIndex:
    def __init__(self, d, size, nprobe):
        self.index = faiss.index_factory(d, f"IVF{size},Flat", faiss.METRIC_INNER_PRODUCT)
        self.index.nprobe = nprobe

    def __getattr__(self, name):
        return getattr(self.index, name)

class KMeansIndex:
    def __init__(self, d, layers, nprobe, m, U):
        self.aug = mips.MipsAugmentationShrivastava(d, m, U)
        self.index = mips.IndexHierarchicKmeans(d, layers, nprobe, self.aug, False)

    def __getattr__(self, name):
        return getattr(self.index, name)

class FlatIndex:
    def __init__(self, d):
        self.index = faiss.IndexFlatIP(d)

    def __getattr__(self, name):
        return getattr(self.index, name)

class AlshIndex:
    def __init__(self, d: int, L: int, K: int, r: int, m: int, U: int):
        self.aug = mips.MipsAugmentationShrivastava(d, m, U)
        self.index = mips.AlshIndex(d, L, K, r, self.aug)

    def __getattr__(self, name):
        return getattr(self.index, name)

class QuantIndex:
    def __init__(self, dim: int, subspace_count: int, centroid_count: int):
        self.index = mips.IndexSubspaceQuantization(dim, subspace_count, centroid_count)

    def __getattr__(self, name):
        return getattr(self.index, name)
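The MipsAugmentationShrivastava objects above reduce MIPS to nearest-neighbor search in the style of Shrivastava & Li's asymmetric LSH: base vectors are rescaled so their norms are at most U and padded with m powers of their norms, while queries are normalized and padded with constant 1/2 entries. A minimal numpy sketch of that idea (an illustration of the published technique, not our C++ implementation) looks roughly like this:
In [ ]:
# Sketch of Shrivastava-style MIPS-to-NN augmentation (illustrative only).
def augment_base_sketch(x, m=5, U=0.85):
    x = x * (U / np.linalg.norm(x, axis=1).max())  # scale so all norms <= U
    norms = np.linalg.norm(x, axis=1)
    # append ||x||^2, ||x||^4, ..., ||x||^(2^m); these terms vanish as m grows
    extra = np.stack([norms ** (2 ** (i + 1)) for i in range(m)], axis=1)
    return np.hstack([x, extra]).astype(np.float32)

def augment_query_sketch(q, m=5):
    q = q / np.linalg.norm(q, axis=1, keepdims=True)  # unit-norm queries
    half = np.full((q.shape[0], m), 0.5, dtype=np.float32)
    return np.hstack([q, half]).astype(np.float32)

# After augmentation, minimizing the L2 distance between an augmented
# query and an augmented base vector approximately maximizes their
# original inner product.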
IVF
In [ ]:
for size in [4096]:
    for nprobe in [1, 16, 32, 64, 128]:
        add_result(test(IVFIndex, d=d, size=size, nprobe=nprobe))
Kmeans
In [ ]:
for layers in [2]:
    for nprobe in [1, 16, 32, 64, 128]:
        add_result(test(KMeansIndex, d=d, layers=layers, nprobe=nprobe, m=5, U=0.85))
IVF
In [ ]:
for size in [4096]:
    for nprobe in [256, 512]:
        add_result(test(IVFIndex, d=d, size=size, nprobe=nprobe))
Kmeans
In [ ]:
for layers in [2]:
    for nprobe in [256, 512]:
        add_result(test(KMeansIndex, d=d, layers=layers, nprobe=nprobe, m=5, U=0.85))
In [ ]:
for layers in [3]:
    for nprobe in [1, 16, 32, 64, 128]:
        add_result(test(KMeansIndex, d=d, layers=layers, nprobe=nprobe, m=5, U=0.85))
In [ ]:
for L in [2, 4, 8, 16, 32]:
    for K in [2, 4, 8, 16, 32]:
        add_result(test(AlshIndex, d=d, L=L, K=K, m=5, r=10, U=0.85))
In [ ]:
for subspace in [2, 4, 8]:
    for centroid in [2, 4, 8, 16, 32]:
        add_result(test(QuantIndex, dim=d, subspace_count=subspace, centroid_count=centroid))
In [ ]:
add_result(test(FlatIndex, d=d))
In [ ]:
DB.all()
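Rather than dumping everything, we can also pull just this run's record and pick the configuration with the best precision-at-1 (failed runs were stored as strings, so we filter them out first):
In [ ]:
# Fetch this run's document and report its best-performing configuration.
run = DB.search(where('ID') == ID)[0]
ok = [r for r in run['results'] if isinstance(r, dict)]
best = max(ok, key=lambda r: r['p1'])
print(best['algo'], best['params'], best['p1'])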