In [9]:
%matplotlib inline

Categorization Example [REST API]

An example illustrating binary categorization with FreeDiscovery


In [11]:
from __future__ import print_function

from time import time, sleep
from multiprocessing import Process
import requests
import pandas as pd

pd.options.display.float_format = '{:,.3f}'.format
pd.options.display.expand_frame_repr = False

dataset_name = "treclegal09_2k_subset"     # see list of available datasets

BASE_URL = "http://localhost:5001/api/v0"  # FreeDiscovery server URL
#BASE_URL = "http://52.38.241.62:5001"
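
Optionally, before running the full pipeline, it can help to confirm that the FreeDiscovery server is reachable. The cell below is a minimal sketch: it reuses the dataset endpoint queried in step 0 and the BASE_URL / dataset_name variables defined above; the 10-second timeout is an arbitrary choice.

In [ ]:
import requests  # already imported above; repeated so this cell stands alone

try:
    # Query the same /datasets endpoint used in step 0; any HTTP error or
    # connection failure is reported rather than raised.
    requests.get(BASE_URL + '/datasets/{}'.format(dataset_name),
                 timeout=10).raise_for_status()
    print('FreeDiscovery server reachable at', BASE_URL)
except requests.exceptions.RequestException as e:
    print('FreeDiscovery server not reachable at {}: {}'.format(BASE_URL, e))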

In [12]:
if __name__ == '__main__':

    print(" 0. Load the test dataset")
    url = BASE_URL + '/datasets/{}'.format(dataset_name)
    print(" GET", url)
    res = requests.get(url).json()

    # To use a custom dataset, simply specify the following variables
    data_dir = res['data_dir']
    relevant_files = res['seed_relevant_files']
    non_relevant_files = res['seed_non_relevant_files']
    ground_truth_file = res['ground_truth_file']  # (optional)


    # 1. Feature extraction

    print("\n1.a Load dataset and initialize feature extraction")
    url = BASE_URL + '/feature-extraction'
    print(" POST", url)
    fe_opts = {'data_dir': data_dir,
               'stop_words': 'english', 'chunk_size': 2000, 'n_jobs': -1,
               'use_idf': 1, 'sublinear_tf': 0, 'binary': 0, 'n_features': 50001,
               'analyzer': 'word', 'ngram_range': (1, 1), "norm": "l2"
              }
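
    # Note (assumption): these options look like standard scikit-learn text
    # vectorization parameters applied server-side to the files in data_dir,
    # e.g. 'n_features' / 'analyzer' / 'ngram_range' for a hashing vectorizer
    # and 'use_idf' / 'sublinear_tf' / 'norm' for the TF-IDF weighting
    # (the parameters echoed in step 1.d below report use_hashing: True).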
    res = requests.post(url, json=fe_opts).json()

    dsid = res['id']
    print("   => received {}".format(list(res.keys())))
    print("   => dsid = {}".format(dsid))

    print("\n1.b Start feature extraction (in the background)")

    # Make this call in a background process so that we can poll its progress below (there may be a cleaner way to do this)
    url = BASE_URL+'/feature-extraction/{}'.format(dsid)
    print(" POST", url)
    p = Process(target=requests.post, args=(url,))
    p.start()
    sleep(5.0) # wait a bit for the processing to start

    print('\n1.c Monitor feature extraction progress')
    url = BASE_URL+'/feature-extraction/{}'.format(dsid)
    print(" GET", url)
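
    # The loop below polls this endpoint until the server reports completion:
    #   - HTTP 200: the extracted features are ready, exit the loop
    #   - HTTP 520: processing did not start, terminate and raise
    #   - any other status: treat as in-progress and print the
    #     'n_samples_processed' / 'n_samples' counters every 15 seconds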

    t0 = time()
    while True:
        res = requests.get(url)
        if res.status_code == 520:
            p.terminate()
            raise ValueError('Processing did not start')
        elif res.status_code == 200:
            break # processing finished
        data = res.json()
        print('     ... {}k/{}k files processed in {:.1f} min'.format(
                    data['n_samples_processed']//1000, data['n_samples']//1000, (time() - t0)/60.))
        sleep(15.0)

    p.terminate()  # just in case, should not be necessary


    print("\n1.d. Check the parameters of the extracted features")
    url = BASE_URL + '/feature-extraction/{}'.format(dsid)
    print(' GET', url)
    res = requests.get(url).json()

    print('\n'.join(['     - {}: {}'.format(key, val) for key, val in res.items() \
                                                      if "filenames" not in key]))


    # 2. Document categorization with ML algorithms

    print("\n2.a. Train the ML categorization model")
    print("       {} relevant, {} non-relevant files".format(
        len(relevant_files), len(non_relevant_files)))
    url = BASE_URL + '/categorization/'
    print(" POST", url)
    print(' Training...')

    res = requests.post(url,
                        json={'relevant_filenames': relevant_files,
                              'non_relevant_filenames': non_relevant_files,
                              'dataset_id': dsid,
                              'method': 'LinearSVC',  # one of "LinearSVC", "LogisticRegression", 'xgboost'
                              'cv': 0                          # Cross Validation
                              }).json()
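
    # 'LinearSVC' and 'LogisticRegression' match the scikit-learn classifiers
    # of the same name (step 2.b below echoes what appear to be the default
    # scikit-learn LinearSVC options); 'xgboost' presumably selects a
    # gradient-boosted tree model, and 'cv': 0 presumably disables
    # cross-validation.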

    mid = res['id']
    print("     => model id = {}".format(mid))
    print('    => Training scores: MAP = {average_precision:.3f}, ROC-AUC = {roc_auc:.3f}'.format(**res))

    print("\n2.b. Check the parameters used in the categorization model")
    url = BASE_URL + '/categorization/{}'.format(mid)
    print(" GET", url)
    res = requests.get(url).json()

    print('\n'.join(['     - {}: {}'.format(key, val) for key, val in res.items() \
                                                      if "filenames" not in key]))

    print("\n2.c Categorize the complete dataset with this model")
    url = BASE_URL + '/categorization/{}/predict'.format(mid)
    print(" GET", url)
    res = requests.get(url).json()
    prediction = res['prediction']
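
    # Each entry of 'prediction' is a per-document decision score; its sign
    # gives the predicted class (positive = relevant, negative = non-relevant),
    # which is how the counts below are obtained.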

    print("    => Predicting {} relevant and {} non-relevant documents".format(
        len(list(filter(lambda x: x>0, prediction))),
        len(list(filter(lambda x: x<0, prediction)))))

    print("\n2.d Test categorization accuracy")
    print("         using {}".format(ground_truth_file))  
    url = BASE_URL + '/categorization/{}/test'.format(mid)
    print("POST", url)
    res = requests.post(url, json={'ground_truth_filename': ground_truth_file}).json()

    print('    => Test scores: MAP = {average_precision:.3f}, ROC-AUC = {roc_auc:.3f}'.format(**res))


    # 3. Document categorization with LSI

    print("\n3.a. Calculate LSI")

    url = BASE_URL + '/lsi/'
    print("POST", url)

    n_components = 100
    res = requests.post(url,
                        json={'n_components': n_components,
                              'dataset_id': dsid
                              }).json()

    lid = res['id']
    print('  => LSI model id = {}'.format(lid))
    print('  => SVD decomposition with {} dimensions explaining {:.2f}% of the data variability'.format(
                            n_components, res['explained_variance']*100))
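
    # 'n_components' is the number of retained SVD dimensions; 100 dimensions
    # explain roughly half of the variability of this dataset (see the output
    # below). A larger value preserves more of the original feature space at
    # the cost of memory and compute time.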
    print("\n3.b. Predict categorization with LSI")

    url = BASE_URL + '/lsi/{}/predict'.format(lid)
    print("POST", url)
    res = requests.post(url,
                        json={'relevant_filenames': relevant_files,
                              'non_relevant_filenames': non_relevant_files
                              }).json()
    prediction = res['prediction']

    print('    => Training scores: MAP = {average_precision:.3f}, ROC-AUC = {roc_auc:.3f}'.format(**res))
    df = pd.DataFrame({key: res[key] for key in res if 'prediction'==key or 'nearest' in key})
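
    # Gather the per-document LSI results into a DataFrame: the prediction
    # score plus the 'nearest_rel_doc' / 'nearest_nrel_doc' columns, which
    # appear to index the closest relevant / non-relevant seed documents
    # (the table is printed after the test step below).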


    print("\n3.c. Test categorization with LSI")
    url = BASE_URL + '/lsi/{}/test'.format(lid)
    print(" POST", url)

    res = requests.post(url,
                        json={'relevant_filenames': relevant_files,
                              'non_relevant_filenames': non_relevant_files,
                              'ground_truth_filename': ground_truth_file
                              }).json()
    print('    => Test scores: MAP = {average_precision:.3f}, ROC-AUC = {roc_auc:.3f}'.format(**res))

    print('\n', df)


    print("\n4.a Delete the extracted features")
    url = BASE_URL + '/feature-extraction/{}'.format(dsid)
    print(" DELETE", url)
    requests.delete(url)


 0. Load the test dataset
 GET http://localhost:5001/api/v0/datasets/treclegal09_2k_subset

1.a Load dataset and initialize feature extraction
 POST http://localhost:5001/api/v0/feature-extraction
   => received ['filenames', 'id']
   => dsid = 3b1e20c376624a7b9b524796125c457a

1.b Start feature extraction (in the background)
 POST http://localhost:5001/api/v0/feature-extraction/3b1e20c376624a7b9b524796125c457a

1.c Monitor feature extraction progress
 GET http://localhost:5001/api/v0/feature-extraction/3b1e20c376624a7b9b524796125c457a

1.d. Check the parameters of the extracted features
 GET http://localhost:5001/api/v0/feature-extraction/3b1e20c376624a7b9b524796125c457a
     - binary: False
     - sublinear_tf: False
     - min_df: 0.0
     - n_jobs: -1
     - use_hashing: True
     - use_idf: True
     - max_df: 1.0
     - n_features: 50001
     - n_samples: 2465
     - stop_words: english
     - data_dir: /shared/code/wking_code/freediscovery_shared/treclegal09_2k_subset/data
     - n_samples_processed: 2465
     - analyzer: word
     - chunk_size: 2000
     - norm: l2
     - ngram_range: [1, 1]

2.a. Train the ML categorization model
       5 relevant, 63 non-relevant files
 POST http://localhost:5001/api/v0/categorization/
 Training...
     => model id = d269a5f8cd904c0fb03aba9e1fff7ef5
    => Training scores: MAP = 1.000, ROC-AUC = 1.000

2.b. Check the parameters used in the categorization model
 GET http://localhost:5001/api/v0/categorization/d269a5f8cd904c0fb03aba9e1fff7ef5
     - method: LinearSVC
     - options: {'loss': 'squared_hinge', 'C': 1.0, 'class_weight': None, 'fit_intercept': True, 'dual': True, 'intercept_scaling': 1, 'verbose': 0, 'penalty': 'l2', 'multi_class': 'ovr', 'max_iter': 1000, 'random_state': None, 'tol': 0.0001}

2.c Categorize the complete dataset with this model
 GET http://localhost:5001/api/v0/categorization/d269a5f8cd904c0fb03aba9e1fff7ef5/predict
    => Predicting 11 relevant and 2454 non-relevant documents

2.d Test categorization accuracy
         using /shared/code/wking_code/freediscovery_shared/treclegal09_2k_subset/ground_truth_file.txt
POST http://localhost:5001/api/v0/categorization/d269a5f8cd904c0fb03aba9e1fff7ef5/test
    => Test scores: MAP = 1.000, ROC-AUC = 1.000

3.a. Calculate LSI
POST http://localhost:5001/api/v0/lsi/
  => LSI model id = 84a6a188929b44359dae4d5f72ec52f4
  => SVD decomposition with 100 dimensions explaining 48.41% of the data variability

3.b. Predict categorization with LSI
POST http://localhost:5001/api/v0/lsi/84a6a188929b44359dae4d5f72ec52f4/predict
    => Training scores: MAP = 1.000, ROC-AUC = 1.000

3.c. Test categorization with LSI
 POST http://localhost:5001/api/v0/lsi/84a6a188929b44359dae4d5f72ec52f4/test
    => Test scores: MAP = 0.751, ROC-AUC = 0.822

       nearest_nrel_doc  nearest_rel_doc  prediction
0                   29                4      -0.414
1                    9                4      -0.466
2                   36                4      -0.568
3                   26                1       1.000
4                   36                1      -0.600
5                   47                2      -0.502
6                   30                1      -0.664
7                   36                0      -0.421
8                   47                1      -0.449
9                   61                4      -1.000
10                  10                1      -0.607
11                  38                3      -0.367
12                  38                3      -0.356
13                  18                0      -0.841
14                  18                0      -0.996
15                  42                2      -0.334
16                  38                1      -0.375
17                  38                3      -0.246
18                  45                2      -0.212
19                   0                4      -0.520
20                  47                2      -0.354
21                  24                4      -0.895
22                   2                1      -0.169
23                   2                2      -0.171
24                  22                0      -0.541
25                  24                4      -0.569
26                  22                3      -0.267
27                  57                0      -0.324
28                  38                1      -0.360
29                  38                3      -0.239
...                ...              ...         ...
2435                14                2      -0.189
2436                52                0      -0.726
2437                52                0      -0.425
2438                47                0      -0.207
2439                41                0      -0.379
2440                49                0      -0.267
2441                41                1      -0.476
2442                26                0      -0.196
2443                61                2      -0.416
2444                61                2      -0.208
2445                61                2      -0.400
2446                61                2      -0.202
2447                36                0      -0.513
2448                36                0      -0.194
2449                47                1      -0.586
2450                25                1      -0.845
2451                10                1      -1.000
2452                10                1      -0.615
2453                47                1      -0.571
2454                47                1      -0.349
2455                35                1      -0.428
2456                12                1       0.656
2457                42                2      -0.547
2458                49                2      -0.267
2459                34                0      -0.726
2460                39                0      -0.209
2461                39                0      -0.200
2462                36                0      -0.220
2463                39                0      -0.216
2464                34                0      -0.501

[2465 rows x 3 columns]

4.a Delete the extracted features
 DELETE http://localhost:5001/api/v0/feature-extraction/3b1e20c376624a7b9b524796125c457a

In [ ]: