In [1]:
import lime
import sklearn
import numpy as np
import sklearn
import sklearn.ensemble
import sklearn.metrics

In [2]:
from sklearn.datasets import fetch_20newsgroups
categories = ['rec.autos', 'rec.motorcycles']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
class_names = ['cars', 'bikes']


Downloading 20news dataset. This may take a few minutes.
Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)

In [3]:
vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(lowercase=True)
train_vectors = vectorizer.fit_transform(newsgroups_train.data)
test_vectors = vectorizer.transform(newsgroups_test.data)

In [4]:
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=500)
rf.fit(train_vectors, newsgroups_train.target)


Out[4]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=500, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)

In [6]:
pred = rf.predict(test_vectors)
print(sklearn.metrics.f1_score(newsgroups_test.target, pred, average='binary'))
print(sklearn.metrics.accuracy_score(newsgroups_test.target, pred))


0.9523809523809523
0.9534005037783375

Explaining predictions


In [7]:
from lime import lime_text
from sklearn.pipeline import make_pipeline
c = make_pipeline(vectorizer, rf)

In [8]:
from lime.lime_text import LimeTextExplainer
explainer = LimeTextExplainer(class_names=class_names)

In [22]:
idx = 42
exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=10)
print('Document id: %d' % idx)
print('Probability(car) =', c.predict_proba([newsgroups_test.data[idx]])[0,0])
print('True class: %s' % class_names[newsgroups_test.target[idx]])
print(exp.as_list())


C:\ProgramData\Anaconda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match.
  return _compile(pattern, flags).split(string, maxsplit)
Document id: 42
Probability(car) = 0.217
True class: bikes
[('Dod', 0.13814370459671274), ('Re', 0.03895770984644524), ('writes', 0.03627443971891233), ('CA', 0.032000360940795), ('Host', 0.026840935935095844), ('com', 0.02580762435823754), ('syl', 0.024978542220913277), ('imat', 0.02398378903786139), ('seahunt', 0.023427049118047792), ('problem', -0.008622724852528888)]

In [14]:
%matplotlib inline
fig = exp.as_pyplot_figure()



In [15]:
exp.show_in_notebook(text=False)



In [16]:
exp.show_in_notebook(text=True)



In [ ]: