In [1]:
import lime
import sklearn
import numpy as np
import sklearn
import sklearn.ensemble
import sklearn.metrics
from sklearn.feature_extraction.text import TfidfVectorizer

In [2]:
from sklearn.datasets import fetch_20newsgroups
# Load the standard 20 Newsgroups train/test splits (downloaded on first use,
# then served from scikit-learn's local data cache).
newsgroups_train, newsgroups_test = (
    fetch_20newsgroups(subset=split) for split in ('train', 'test')
)

In [3]:
# Learn the TF-IDF vocabulary and IDF weights from the training split only, then
# apply those same weights to the test split so no test statistics leak into training.
# TfidfVectorizer is imported explicitly: `import sklearn` alone does not load the
# `sklearn.feature_extraction.text` submodule, so the dotted path only resolved when
# some other import had already pulled it in.
vectorizer = TfidfVectorizer(lowercase=True)
train_vectors = vectorizer.fit_transform(newsgroups_train.data)
test_vectors = vectorizer.transform(newsgroups_test.data)

In [4]:
from sklearn.naive_bayes import MultinomialNB
# Multinomial naive Bayes over the TF-IDF features; alpha=1e-2 keeps additive
# smoothing light. The cell ends on `fit`, which returns the estimator, so its
# repr is displayed as the cell output.
nb = MultinomialNB(alpha=1e-2)
nb.fit(X=train_vectors, y=newsgroups_train.target)


Out[4]:
MultinomialNB(alpha=0.01, class_prior=None, fit_prior=True)

In [7]:
# Score the classifier on the held-out test split: support-weighted F1, then accuracy.
pred = nb.predict(test_vectors)
weighted_f1 = sklearn.metrics.f1_score(newsgroups_test.target, pred, average='weighted')
accuracy = sklearn.metrics.accuracy_score(newsgroups_test.target, pred)
print(weighted_f1, accuracy, sep='\n')


0.8337679455301017
0.8352363250132767

In [8]:
from lime import lime_text
from sklearn.pipeline import make_pipeline
# Chain the fitted vectorizer and classifier so LIME gets a single callable that
# maps raw strings to class probabilities (`c.predict_proba`).
c = make_pipeline(vectorizer, nb)
from lime.lime_text import LimeTextExplainer
# LIME builds each explanation from randomly perturbed copies of the document;
# a fixed seed makes the reported feature weights reproducible across re-runs.
LIME_RANDOM_STATE = 42
explainer = LimeTextExplainer(
    class_names=newsgroups_train.target_names,
    random_state=LIME_RANDOM_STATE,
)

In [19]:
# Explain one test document that the model gets wrong (true: alt.atheism,
# predicted: soc.religion.christian in the recorded run).
idx = 1340
doc = newsgroups_test.data[idx]
explained_label = 0  # alt.atheism; label 17 is explained too but not printed here
exp = explainer.explain_instance(doc, c.predict_proba, num_features=6, labels=[explained_label, 17])
# Predict through the same pipeline LIME is explaining, on the raw text. This is
# equivalent to nb.predict(test_vectors[idx]) because the TF-IDF transform is per-document.
predicted = c.predict([doc])[0]
print('Document id: %d' % idx)
print('Predicted class =', newsgroups_train.target_names[predicted])
print('True class: %s' % newsgroups_train.target_names[newsgroups_test.target[idx]])
print('EXP for class %s' % newsgroups_train.target_names[explained_label])
print(exp.as_list(label=explained_label))


C:\ProgramData\Anaconda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match.
  return _compile(pattern, flags).split(string, maxsplit)
Document id: 1340
Predicted class = soc.religion.christian
True class: alt.atheism
EXP for class alt.atheism
[('Theism', 0.08675815330898323), ('Rice', 0.08556269237636434), ('owlnet', -0.0767723905721994), ('scri', -0.074665910800009), ('Genocide', 0.07117701439107631), ('certainty', -0.060051439124447575)]

In [20]:
# Rather than naming classes up front, let LIME explain whichever two labels
# the pipeline scores highest for this document, then show which ones it chose.
exp = explainer.explain_instance(
    text_instance=newsgroups_test.data[idx],
    classifier_fn=c.predict_proba,
    top_labels=2,
    num_features=6,
)
print(exp.available_labels())


C:\ProgramData\Anaconda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match.
  return _compile(pattern, flags).split(string, maxsplit)
[15, 19]

In [21]:
# Render the explanation for the top-scoring class, with the document text shown and
# the explanatory words highlighted. `text` is a boolean "show the text" flag
# (LIME re-renders the explained document itself). With `top_labels`,
# available_labels() lists classes in descending probability, so [0] is the top
# prediction. This avoids hard-coding a class id that might not have been explained.
exp.show_in_notebook(text=True, labels=(exp.available_labels()[0],))



In [ ]: