In [1]:
# Allows us to import packages that exist one level up in the file system
# See https://stackoverflow.com/questions/34478398
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path = [module_path] + sys.path

In [2]:
from tagnews.utils import load_data as ld
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import sklearn.feature_extraction.text
import sklearn.multiclass
import sklearn.linear_model
% matplotlib inline
plt.rcParams['figure.figsize'] = 12, 8

In [3]:
df = ld.load_data()

In [4]:
# TODO: Augment training data with not relevant

In [5]:
crime_df = df.loc[df.loc[:, 'OEMC':'TASR'].any(1), :]
print(crime_df.shape)
crime_df = crime_df.append(df.loc[~df['relevant'], :].sample(n=3000, axis=0))
print(crime_df.shape)

idx = np.random.permutation(crime_df.shape[0])
trn = crime_df.iloc[idx[:int(crime_df.shape[0] * 0.7)], :]
tst = crime_df.iloc[idx[int(crime_df.shape[0] * 0.7):], :]
print(trn.shape)
print(tst.shape)


(40008, 50)
(43008, 50)
(30105, 50)
(12903, 50)

In [6]:
# vectorize data
from nltk import word_tokenize          
from nltk.stem import WordNetLemmatizer 
class LemmaTokenizer(object):
    def __init__(self):
        self.wnl = WordNetLemmatizer()
    def __call__(self, doc):
        return [self.wnl.lemmatize(t) for t in word_tokenize(doc)]

vectorizer = sklearn.feature_extraction.text.CountVectorizer(tokenizer=LemmaTokenizer(),
                                                             binary=True)
X = vectorizer.fit_transform(trn['bodytext'].values)

Y = trn.loc[:, 'OEMC':'TASR'].values

In [7]:
X.shape


Out[7]:
(30105, 225457)

In [8]:
from tagnews.crimetype import benchmark as bt

In [9]:
bench_results = bt.benchmark(
    lambda: sklearn.multiclass.OneVsRestClassifier(
        sklearn.linear_model.LogisticRegression()
    ),
    vectorizer.transform(crime_df['bodytext'].values),
    crime_df.loc[:, 'OEMC':'TASR'].values
)

In [10]:
fpr = pd.DataFrame(bench_results['fpr'], columns=df.loc[:, 'OEMC':'TASR'].columns.values).T

tpr = pd.DataFrame(bench_results['tpr'], columns=df.loc[:, 'OEMC':'TASR'].columns.values).T

ppv = pd.DataFrame(bench_results['ppv'], columns=df.loc[:, 'OEMC':'TASR'].columns.values).T

In [11]:
bench_results


Out[11]:
{'acc': array([ 0.96667205,  0.96641996,  0.96679443,  0.96653744]),
 'clfs': [OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
            intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
            penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
            verbose=0, warm_start=False),
            n_jobs=1),
  OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
            intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
            penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
            verbose=0, warm_start=False),
            n_jobs=1),
  OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
            intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
            penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
            verbose=0, warm_start=False),
            n_jobs=1),
  OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
            intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
            penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
            verbose=0, warm_start=False),
            n_jobs=1)],
 'fpr': array([[  2.79876854e-04,   1.67719898e-01,   1.92074404e-02,
           1.84725395e-02,   4.55883750e-03,   7.98076923e-03,
           1.12222950e-03,   1.58997381e-03,   9.10323455e-03,
           6.04978677e-03,   2.73739853e-03,   5.84011489e-03,
           4.29516969e-02,   2.35693410e-03,   2.53942249e-02,
           2.79981335e-04,   1.60687368e-01,   1.78186252e-03,
           1.21779859e-03,   1.31517144e-03,   5.48100548e-03,
           1.32293415e-02,   4.06740267e-03,   1.77755206e-02,
           2.57265364e-03,   4.80225989e-02,   1.69077588e-03,
           9.34579439e-05,   8.42144662e-04,   9.31272118e-05,
           9.77998066e-02,   1.86828585e-04,   1.87846342e-04,
           5.51257253e-03,   3.35602646e-03,   9.06024096e-03,
           1.98315385e-02,   2.79902967e-04],
        [  1.86532363e-04,   1.72026279e-01,   2.62786951e-02,
           1.75906710e-02,   3.89919163e-03,   8.28755902e-03,
           4.67071462e-04,   8.43881857e-04,   1.01157475e-02,
           6.13679105e-03,   2.73688184e-03,   6.39618138e-03,
           4.54805492e-02,   2.16572505e-03,   3.09152127e-02,
           3.73378139e-04,   1.67318688e-01,   1.97349873e-03,
           1.21791269e-03,   1.22042809e-03,   2.74153904e-03,
           1.15250869e-02,   4.36978054e-03,   1.84223919e-02,
           3.24025541e-03,   4.67505780e-02,   1.50333553e-03,
           5.60224090e-04,   4.67683098e-04,   9.30578820e-05,
           9.95392823e-02,   1.86863496e-04,   4.69748215e-04,
           6.45285563e-03,   4.60034503e-03,   6.27292029e-03,
           2.12517267e-02,   9.33271115e-05],
        [  2.79850746e-04,   1.68966105e-01,   2.28891150e-02,
           2.11037353e-02,   5.51435634e-03,   6.56307306e-03,
           5.61429774e-04,   7.49765698e-04,   7.18516361e-03,
           5.55665807e-03,   1.70084097e-03,   5.17191840e-03,
           3.88101983e-02,   1.69157034e-03,   3.06258322e-02,
           4.66897003e-04,   1.55582305e-01,   2.62812089e-03,
           8.43723634e-04,   1.87828700e-03,   3.69248248e-03,
           1.08846230e-02,   5.53344336e-03,   1.90117934e-02,
           4.18808300e-03,   5.11378164e-02,   2.24866486e-03,
           7.46407912e-04,   7.48152997e-04,   0.00000000e+00,
           9.97102849e-02,   7.47314339e-04,   1.40792191e-03,
           6.26566416e-03,   2.97390637e-03,   6.97066512e-03,
           2.31299735e-02,   9.32227091e-05],
        [  2.79929085e-04,   1.71544578e-01,   2.36363636e-02,
           2.04404068e-02,   4.65912332e-03,   5.68784344e-03,
           7.48083037e-04,   1.12454315e-03,   8.75401226e-03,
           5.65195835e-03,   2.73714016e-03,   5.16548689e-03,
           4.05638665e-02,   2.26137756e-03,   3.33915545e-02,
           5.59597090e-04,   1.72899319e-01,   2.53640207e-03,
           8.43091335e-04,   1.50164242e-03,   5.10155881e-03,
           1.16912712e-02,   4.75820548e-03,   1.84285136e-02,
           3.81351892e-03,   4.87205996e-02,   2.06417714e-03,
           7.47104968e-04,   9.36066648e-04,   0.00000000e+00,
           9.86423165e-02,   6.53838969e-04,   1.59549507e-03,
           7.29436606e-03,   3.84172109e-03,   7.46847721e-03,
           2.11702128e-02,   2.79798545e-04]]),
 'ppv': array([[ 0.72727273,  0.8029695 ,  0.74666667,  0.61885246,  0.72093023,
          0.74772036,  0.625     ,  0.58536585,  0.72910663,  0.9004894 ,
          0.71287129,  0.73245614,  0.91439798,  0.78448276,  0.64872521,
          0.625     ,  0.72916667,  0.63461538,  0.70454545,  0.74074074,
          0.38947368,  0.76071429,  0.88709677,  0.79141836,  0.84210526,
          0.86941341,  0.76      ,  0.75      ,  0.82      ,  0.83333333,
          0.63476298,  0.83333333,  0.96666667,  0.78888889,  0.87889273,
          0.73065903,  0.84602649,  0.76923077],
        [ 0.71428571,  0.80451866,  0.66538462,  0.61220044,  0.76571429,
          0.74705882,  0.7826087 ,  0.775     ,  0.74257426,  0.89527027,
          0.74107143,  0.7112069 ,  0.91325696,  0.74444444,  0.60183968,
          0.55555556,  0.72765124,  0.66666667,  0.675     ,  0.77586207,
          0.5915493 ,  0.79755672,  0.87837838,  0.78904429,  0.79881657,
          0.8722807 ,  0.81395349,  0.14285714,  0.88888889,  0.66666667,
          0.62322166,  0.83333333,  0.93506494,  0.75547445,  0.83275261,
          0.79750779,  0.83207389,  0.93333333],
        [ 0.7       ,  0.80321365,  0.7133758 ,  0.60989011,  0.69148936,
          0.79076923,  0.77777778,  0.76470588,  0.77777778,  0.90834697,
          0.82178218,  0.75115207,  0.92335664,  0.8021978 ,  0.62007624,
          0.44444444,  0.74039669,  0.58208955,  0.775     ,  0.70149254,
          0.51851852,  0.78805395,  0.85606061,  0.78025852,  0.75956284,
          0.86310746,  0.71084337,  0.2       ,  0.77777778,  1.        ,
          0.61900369,  0.46666667,  0.8125    ,  0.77031802,  0.89007092,
          0.79190751,  0.82504013,  0.83333333],
        [ 0.76923077,  0.79560572,  0.69649805,  0.58846918,  0.71511628,
          0.81677019,  0.68      ,  0.71428571,  0.75806452,  0.905     ,
          0.75213675,  0.74285714,  0.92372194,  0.77570093,  0.61492891,
          0.5       ,  0.72301722,  0.59090909,  0.74285714,  0.74603175,
          0.4375    ,  0.78066914,  0.87931034,  0.77221527,  0.79591837,
          0.87066895,  0.72839506,  0.2       ,  0.81132075,  1.        ,
          0.62287552,  0.53333333,  0.79012346,  0.72363636,  0.86531987,
          0.8025641 ,  0.84092726,  0.8125    ]]),
 'tpr': array([[ 0.24242424,  0.81723504,  0.65116279,  0.44216691,  0.55605381,
          0.69886364,  0.33898305,  0.4       ,  0.59389671,  0.82511211,
          0.4556962 ,  0.54397394,  0.89235826,  0.62758621,  0.46450304,
          0.13513514,  0.69694026,  0.37078652,  0.4025974 ,  0.37383178,
          0.21764706,  0.68378812,  0.77464789,  0.73208379,  0.56031128,
          0.84008097,  0.53773585,  0.05769231,  0.63076923,  0.35714286,
          0.56693548,  0.21276596,  0.55238095,  0.51699029,  0.78637771,
          0.67639257,  0.74435543,  0.29411765],
        [ 0.16666667,  0.82427536,  0.63292683,  0.44391785,  0.56540084,
          0.67733333,  0.38297872,  0.35632184,  0.63694268,  0.81664099,
          0.53205128,  0.59566787,  0.89042553,  0.50757576,  0.48159832,
          0.12820513,  0.72357526,  0.37837838,  0.34615385,  0.45      ,
          0.24137931,  0.66521106,  0.71585903,  0.73031284,  0.52123552,
          0.83816588,  0.64220183,  0.02380952,  0.6557377 ,  0.33333333,
          0.54233227,  0.20408163,  0.66666667,  0.56097561,  0.75157233,
          0.65641026,  0.73900075,  0.37837838],
        [ 0.21875   ,  0.82954313,  0.60737527,  0.50531108,  0.55555556,
          0.657289  ,  0.32307692,  0.31707317,  0.57174393,  0.82344214,
          0.49112426,  0.52411576,  0.89409534,  0.65765766,  0.4934277 ,
          0.09302326,  0.71821724,  0.39795918,  0.36470588,  0.45192308,
          0.22105263,  0.63312693,  0.75166297,  0.72489083,  0.56504065,
          0.86075085,  0.74683544,  0.05882353,  0.47457627,  0.16666667,
          0.54376013,  0.14893617,  0.66326531,  0.57671958,  0.7652439 ,
          0.64775414,  0.77467973,  0.2       ],
        [ 0.28571429,  0.81532741,  0.63028169,  0.4736    ,  0.52340426,
          0.6939314 ,  0.29310345,  0.37037037,  0.59872611,  0.81409295,
          0.56050955,  0.52348993,  0.89868421,  0.5971223 ,  0.50932287,
          0.2       ,  0.71883289,  0.36448598,  0.33766234,  0.48453608,
          0.25149701,  0.63732929,  0.78634361,  0.7043379 ,  0.59315589,
          0.84207034,  0.62765957,  0.04545455,  0.62318841,  0.58333333,
          0.55825443,  0.17391304,  0.65979381,  0.5975976 ,  0.75588235,
          0.7081448 ,  0.77810651,  0.43333333]])}

In [12]:
f, axs = plt.subplots(3,1)
tpr.mean(axis=1).plot(kind='bar', ax=axs[0])
axs[0].set_ylabel('TPR')
axs[0].set_xticklabels([])
axs[0].set_ylim([0, 1])
ppv.mean(axis=1).plot(kind='bar', ax=axs[1])
axs[1].set_ylabel('PPV')
axs[1].set_xticklabels([])
axs[1].set_ylim([0, 1])
(1 - fpr).mean(axis=1).plot(kind='bar', ax=axs[2])
axs[2].set_ylabel('1 - FPR')
axs[2].set_ylim([0, 1])
plt.tight_layout()
plt.show()



In [13]:
df.loc[:, 'OEMC'::].columns


Out[13]:
Index(['OEMC', 'CPD', 'SAO', 'CCCC', 'CCJ', 'CCSP', 'CPUB', 'IDOC', 'DOMV',
       'SEXA', 'POLB', 'POLM', 'GUNV', 'GLBTQ', 'JUVE', 'REEN', 'VIOL', 'BEAT',
       'PROB', 'PARL', 'CPLY', 'DRUG', 'CPS', 'GANG', 'ILSP', 'HOMI', 'IPRA',
       'CPBD', 'IMMG', 'ENVI', 'UNSPC', 'ILSC', 'ARSN', 'BURG', 'DUI', 'FRUD',
       'ROBB', 'TASR', 'COPA', 'DIGP'],
      dtype='object')

In [14]:
# this will write 10 files to the notebooks directory
bt.predict_articles(bench_results['clfs'][0], vectorizer, df, n=10)

In [15]:
# sanity check
clf = bench_results['clfs'][0]
pd.DataFrame(
    clf.predict_proba(vectorizer.transform(['marijuana'])),
    columns=df.loc[:, 'OEMC':'TASR'].columns
).T.sort_values(0, ascending=False)


Out[15]:
0
DRUG 0.864677
ENVI 0.228001
CPS 0.227901
UNSPC 0.192679
TASR 0.190943
GLBTQ 0.185953
OEMC 0.181800
VIOL 0.179344
IMMG 0.175331
CPUB 0.175017
DUI 0.174134
CPLY 0.168016
IPRA 0.162643
IDOC 0.161492
REEN 0.161107
PROB 0.159612
PARL 0.156219
ARSN 0.156125
BEAT 0.155040
ILSC 0.148013
CCJ 0.146693
BURG 0.146191
POLM 0.130751
POLB 0.130622
ILSP 0.128354
FRUD 0.125557
GUNV 0.125396
SEXA 0.123399
CPD 0.120370
CCSP 0.109191
SAO 0.104256
CPBD 0.103777
GANG 0.098340
CCCC 0.095742
DOMV 0.086720
HOMI 0.086437
ROBB 0.082713
JUVE 0.073748

In [16]:
not_yet_tagged = df.loc[df['relevant'] & ~df.loc[:, 'OEMC':'TASR'].any(1), :]

In [17]:
not_yet_tagged_preds = pd.DataFrame(
    clf.predict_proba(vectorizer.transform(not_yet_tagged.loc[:, 'bodytext'].values)),
    columns=df.loc[:, 'OEMC':'TASR'].columns
)

In [18]:
f, ax = plt.subplots(1, figsize=(8,8))
percents = not_yet_tagged_preds.max(axis=1).sort_values().reset_index().values[:,1]
filtered = np.linspace(0, 100, not_yet_tagged_preds.shape[0])

ax.plot(percents, filtered)

for filtered_p in [25, 40, 50]:
    p = percents[np.where(filtered >= filtered_p)[0][0]]
    plt.plot([0, p, p],
             [filtered_p, filtered_p, 0],
             '--',
             label='filtered {}% of data with threshold {:.2%}'.format(filtered_p, p))

ax.grid(True)
ax.set_xlim([0, 1])
ax.set_ylim([0, 100])
ax.set_xlabel('Probability threshold $p$', fontsize=16)
ax.set_title('% of data with all tags < probability $p$', fontsize=16)
ax.legend(loc='lower right', fontsize=14)


Out[18]:
<matplotlib.legend.Legend at 0x1a1d21ee10>

In [19]:
import pickle

curr_time = time.strftime('%Y%m%d-%H%M%S')
with open('model-' + curr_time + '.pkl', 'wb') as f:
    pickle.dump(bench_results['clfs'][0], f)
with open('vectorizer-' + curr_time + '.pkl', 'wb') as f:
    pickle.dump(vectorizer, f)


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-19-cc68de12024f> in <module>()
      1 import pickle
      2 
----> 3 curr_time = time.strftime('%Y%m%d-%H%M%S')
      4 with open('model-' + curr_time + '.pkl', 'wb') as f:
      5     pickle.dump(bench_results['clfs'][0], f)

NameError: name 'time' is not defined

In [ ]: