Prerequisites:

  • Python 3.6; we use f-string notation (PEP 498). Alternatively, you can modify the code to use Python 3.5-compatible formatting

  • The best way to get Python is via Anaconda

  • You need to install TinyDB: pip install tinydb

  • You need to build our Python wrappers: make bind

  • You need the faiss Python wrappers

  • You need to build our FastText fork: cd fastText ; make

  • You need to get the dataset from the Extreme Classification Repository

  • The downloaded dataset files have to be renamed to train.txt and test.txt (or you can fix the loader code)

  • Data should be in <root>/data, or you can change the paths in this script


In [ ]:
cd ..

Data

System-related imports


In [ ]:
import sys
import os

Add these directories to sys.path so that we can import from them


In [ ]:
sys.path.append(os.path.abspath('./faiss/'))
sys.path.append(os.path.abspath('./python/'))

Import our functions to read the Extreme Classification Repository data format and to transform it to FastText format


In [ ]:
from experiments.data import get_data
from misc.utils import to_ft, load_sift

Read the data in exrepo format and transform it to FastText format (stored in ./data/LSHTC-FT/*.txt)


In [ ]:
X, Y, words_mask, labels_mask = get_data('./data/LSHTC', 'train', min_words=3, min_labels=3)
to_ft(X, Y, './data/LSHTC-FT/train.txt')

X, Y, *_ = get_data('./data/LSHTC', 'test', words_mask=words_mask, labels_mask=labels_mask)
to_ft(X, Y, './data/LSHTC-FT/test.txt')
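
Each line of the resulting files follows the FastText supervised format: the example's labels first, then its words, e.g. __label__314 __label__25 word1 word2 ... (the exact label prefix is whatever to_ft emits; __label__ is FastText's default).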

Generate the fasttext commands to be run. They will produce our MIPS dataset.

The first command trains a FastText model. The second one, instead of calling predict(), goes through each example and writes three things to disk:

  • the correct answers (groundtruth labels)

  • the hidden vectors from the model (our QUERY vectors)

  • the output-weight matrix (our BASE vectors)


In [ ]:
def make_cmd(*args, **kwargs):
    args = ' '.join(args)
    opts = ' '.join(f'-{k} {v}' for k, v in kwargs.items())
    cmd  = f'./fastText/fasttext {args} {opts}'
    
    return cmd.split()

train_cmd = make_cmd('supervised', 
                     input         = './data/LSHTC-FT/train.txt',
                     output        = './data/LSHTC-FT/model.ft',
                     minCount      = 5,
                     minCountLabel = 5,
                     lr            = 0.1,
                     lrUpdateRate  = 100,
                     dim           = 256,
                     ws            = 5,
                     epoch         = 25,
                     neg           = 25,
                     loss          = 'ns',
                     thread        = 8,
                     saveOutput    = 1)

generate_cmd = make_cmd('to-fvecs',
                        './data/LSHTC-FT/model.ft.bin',
                        './data/LSHTC-FT/test.txt',
                        './data/LSHTC-FT/fvecs')
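
To inspect exactly what will be run, join the token lists back into single strings (the option order follows the keyword arguments above):


In [ ]:
print(' '.join(train_cmd))
print(' '.join(generate_cmd))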

Run the fasttext commands prepared above


In [ ]:
import subprocess

subprocess.call(train_cmd)
subprocess.call(generate_cmd)
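
Note that subprocess.call only returns the exit code and will not raise on failure. If you prefer to fail fast, subprocess.check_call is a drop-in alternative:


In [ ]:
# Raises CalledProcessError on a non-zero exit code
subprocess.check_call(train_cmd)
subprocess.check_call(generate_cmd)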

Tests

We'll run several indices on the data that we generated above

External imports


In [ ]:
import datetime
import json
import logging
import os
import sys
import time
import uuid

import numpy as np

from contextlib import contextmanager
from tinydb import TinyDB, where

Our imports


In [ ]:
sys.path.append(os.path.abspath('./faiss/'))
sys.path.append(os.path.abspath('./python/'))

import faiss
import mips

from experiments.data import get_data
from misc.utils import to_ft, load_sift

Some basic logging


In [ ]:
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()

Simple utility functions


In [ ]:
@contextmanager
def timer():
    """ Simple context manager to convieniently measure times between enter and exit """
    
    class Clock:
        elapsed = 0
        
    t0 = time.time()
    
    yield Clock
    
    Clock.elapsed = time.time() - t0


def search(idx, data, k):
    """ Search top-k items in `data` using `idx` 
    
    This is needed because currently our wrappers return 1-D arrays, so we need to reshape them

    """
    
    D, I = idx.search(data, k)
    D, I = D.reshape(-1, k), I.reshape(-1, k)
    
    return D, I


def compute_p1(G, I):
    """ Compute precision-at-1 for groundtruth `G` and predicted indices `I` """
    
    p1 = 0.
    for i, item in enumerate(I):
        p1 += float(int(item) in G[i])

    p1 /= len(G)
    
    return p1


def test_idx(IdxClass, params, xb, xq, G, k=100):
    """ Train and test the given Index class with given params 
    
    Use the provided base, query, and groundtruth vectors. 
    
    Index will predict top-`k` entries.
    
    In case of failure, the exception is returned as string
    
    This function returns a Report dictionary
    
    """
    
    try:

        idx = IdxClass(**params)

        with timer() as train_t:
            idx.train(xb)
            idx.add(xb)

        with timer() as search_t:
            _, I = search(idx, xq, k)

        p1 = compute_p1(G, I[:, 0])

        report = make_report(IdxClass, params, p1, train_t.elapsed, search_t.elapsed)
    
    except Exception as e:
        print('FAILED: ' + str(e))
        report = str(e)

    return report


def now():
    """ Helper function to format current timestamp """
    
    return datetime.datetime.fromtimestamp(time.time()).strftime("%d-%m-%y %H:%M:%S")


def make_report(IdxClass, params, p1, train_t, search_t):
    """ Create a Report dictionary for given set of parameters """
    
    return {
        'ID': uuid.uuid4().hex,
        'algo': IdxClass.__name__,
        'params': params,
        'p1': p1,
        'train_t': train_t,
        'search_t': search_t
    }


def add_result(r):
    """ Add result `r` to the database """
    
    if isinstance(r, dict):
        algo, params, p1, t = r['algo'], r['params'], r['p1'], r['search_t']
        rep = f'(params={params}, p1={p1:.2f}, t={t:.2f})'
    else:
        rep = r
    
    logger.info(f'Adding: {rep}')
    
    def result_adder(doc):
        doc['results'].append(r)
        
    DB.update(result_adder, where('ID') == ID)


def test(IdxClass, **params):
    """ Even higher-level wrapper for testing """
    
    return test_idx(IdxClass, params, xb, xq, G, k=100)
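
A quick sanity check of the helpers above on toy values (the numbers below are made up for illustration):


In [ ]:
with timer() as t:
    time.sleep(0.1)
print(t.elapsed)                    # roughly 0.1

G_toy = [{1, 2}, {3}]               # groundtruth label sets for two queries
I_toy = np.array([1, 5])            # top-1 prediction for each query
print(compute_p1(G_toy, I_toy))     # 0.5: only the first prediction hits its set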

Create the database we'll store the results in


In [ ]:
# !rm ./data/results/ad-hoc-db.json
DB = TinyDB('./data/results/ad-hoc-db.json')
ID = uuid.uuid4().hex

info = dict(
    ID = ID,
    name = 'ad-hoc-results',
    date = now(),
    results = []
)

DB.insert(info)
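
A run can later be retrieved by its ID, for example:


In [ ]:
run = DB.search(where('ID') == ID)[0]
print(run['name'], run['date'], len(run['results']))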

Load the generated dataset


In [ ]:
xq = load_sift('./data/LSHTC-FT/fvecs.hid.fvecs', dtype=np.float32)
xb = load_sift('./data/LSHTC-FT/fvecs.wo.fvecs', dtype=np.float32)

_n, d, c = xq.shape[0], xq.shape[1], xb.shape[0]
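
Queries and base vectors must share the same dimensionality; a quick sanity check:


In [ ]:
assert xq.shape[1] == xb.shape[1] == d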

We do not run on all queries, but on a random subset of LIMIT queries


In [ ]:
LIMIT = 250_000

# Sample LIMIT queries without replacement
inds = np.random.choice(np.arange(_n), LIMIT, replace=False)
xq   = xq[inds, :]

# FAISS expects C-contiguous arrays, so make sure both are
xq = np.ascontiguousarray(xq)
xb = np.ascontiguousarray(xb)

n = xq.shape[0]

Load groundtruth


In [ ]:
# Each line of the labels file holds the groundtruth label set for one query
G = []
with open('./data/LSHTC-FT/fvecs.labels.txt') as f:
    for line in f:
        G.append({int(y) for y in line.split()})

# Keep only the sampled queries, in the same order as xq
G = [G[idx] for idx in inds]

Logging


In [ ]:
logger.info(f"Loaded dataset of {_n:_}, {d:_}-dimensionsl queries (examples), but limiting to {LIMIT:_} queries")
logger.info(f"The dataset contains {c:_} classes, and more than one class can be positive")

Proxies

These are simple helper classes that delegate method calls to the underlying index, while exposing a consistent API for creating the indices


In [ ]:
class IVFIndex:
    def __init__(self, d, size, nprobe):
        self.index = faiss.index_factory(d, f"IVF{size},Flat", faiss.METRIC_INNER_PRODUCT)
        self.index.nprobe = nprobe

    def __getattr__(self, name):
        return getattr(self.index, name)
    
    
class KMeansIndex:
    def __init__(self, d, layers, nprobe, m, U):
        self.aug = mips.MipsAugmentationShrivastava(d, m, U)
        self.index = mips.IndexHierarchicKmeans(d, layers, nprobe, self.aug, False)

    def __getattr__(self, name):
        return getattr(self.index, name)
    
        
class FlatIndex:
    def __init__(self, d):
        self.index = faiss.IndexFlatIP(d)

    def __getattr__(self, name):
        return getattr(self.index, name)
    
class AlshIndex:
    def __init__(self, d: int, L: int, K: int, r: int, m: int, U: float):
        self.aug = mips.MipsAugmentationShrivastava(d, m, U)
        self.index = mips.AlshIndex(d, L, K, r, self.aug)

    def __getattr__(self, name):
        return getattr(self.index, name)
       
class QuantIndex:
    def __init__(self, dim: int, subspace_count: int, centroid_count: int):
        self.index = mips.IndexSubspaceQuantization(dim, subspace_count, centroid_count)
       
    def __getattr__(self, name):
        return getattr(self.index, name)
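
Any attribute not defined on a proxy is forwarded to the wrapped index via __getattr__, so the proxies are used exactly like the underlying objects. A minimal sketch:


In [ ]:
idx = FlatIndex(d)
print(idx.ntotal)    # 0: attribute lookup is forwarded to the faiss index
idx.add(xb)
print(idx.ntotal)    # now equals c, the number of base vectors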

Round 1

Here we start calling our indices on the data, and log the results.

IVF


In [ ]:
for size in [4096]:
    for nprobe in [1, 16, 32, 64, 128]:
        add_result(
            test(
                IVFIndex, d=d, size=size, nprobe=nprobe))

Kmeans


In [ ]:
for layers in [2]:
    for nprobe in [1, 16, 32, 64, 128]:
        add_result(
            test(
                KMeansIndex, d=d, layers=layers, nprobe=nprobe, m=5, U=0.85))

Round 2

IVF


In [ ]:
for size in [4096]:
    for nprobe in [256, 512]:
        add_result(
            test(
                IVFIndex, d=d, size=size, nprobe=nprobe))

Kmeans


In [ ]:
for layers in [2]:
    for nprobe in [256, 512]:
        add_result(
            test(
                KMeansIndex, d=d, layers=layers, nprobe=nprobe, m=5, U=0.85))

In [ ]:
for layers in [3]:
    for nprobe in [1, 16, 32, 64, 128]:
        add_result(
            test(
                KMeansIndex, d=d, layers=layers, nprobe=nprobe, m=5, U=0.85))

Test ALSH and quantization


In [ ]:
for L in [2,4,8,16,32]:
    for K in [2, 4, 8, 16, 32]:
        add_result(
            test(
                AlshIndex, d=d, L=L, K=K, m=5, r=10, U=0.85))

In [ ]:
for subspace in [2,4,8]:
    for centroid in [2, 4, 8, 16, 32]:
        add_result(
            test(
                QuantIndex, dim=d, subspace_count=subspace, centroid_count=centroid))

Baseline

This is a flat index, i.e. an exhaustive scan over all base vectors for each query, but implemented efficiently (FAISS)


In [ ]:
add_result(
    test(
        FlatIndex, d=d))
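
For a handful of queries, the exact answer that IndexFlatIP computes can be reproduced with a brute-force numpy scan (far slower than FAISS, shown only to make the baseline explicit):


In [ ]:
q = xq[:5]
top1 = np.argmax(q @ xb.T, axis=1)   # argmax of inner products over all base vectors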

Full Results


In [ ]:
DB.all()
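
To pick out the best-performing configuration of this run (failed runs are stored as plain strings, so filter them out first):


In [ ]:
results = DB.search(where('ID') == ID)[0]['results']
ok = [r for r in results if isinstance(r, dict)]
best = max(ok, key=lambda r: r['p1'])
print(best['algo'], best['params'], best['p1'])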
