In the previous sections of this tutorial series, you have learned how to train models with DeepChem on a variety of applications. You have also learned about modeling the uncertainty associated with a model. But we have not yet really studied the question of model explainability.
Often, when modeling, we are asked: how does the model work? Why should we trust this model? My response as a data scientist is usually "because we have rigorously demonstrated model performance on a held-out test set, with splits that are realistic to the real world". Oftentimes that is not enough to convince domain experts.
LIME is a tool that can help with this problem. It uses local perturbations of feature space to determine feature importance. LIME has been used to produce human-understandable explanations for image and text models; can it do the same for molecules? In this tutorial, you'll learn how to use LIME alongside DeepChem to interpret what our fixed-length featurization models are learning.
This tutorial and the rest in this sequence are designed to be done in Google Colab. If you'd like to open this notebook in Colab, you can use the following link.
To run DeepChem within Colab, you'll need to run the following cell of installation commands. This will take about 5 minutes to run to completion and install your environment.
In [1]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')
In [2]:
from deepchem.molnet import load_tox21
# Load Tox21 dataset
n_features = 1024
tox21_tasks, tox21_datasets, transformers = load_tox21(reload=False)
train_dataset, valid_dataset, test_dataset = tox21_datasets
Let's now define a model to work on this dataset. Because LIME needs a fixed-length input vector, for now we can only use a fully connected network model on the fingerprint features.
In [3]:
import deepchem as dc
n_tasks = len(tox21_tasks)
n_features = train_dataset.get_data_shape()[0]
model = dc.models.MultitaskClassifier(n_tasks, n_features)
Our next goal is to train this model on the Tox21 dataset. Let's train for 10 epochs so we have a reasonably converged model.
In [4]:
num_epochs = 10
losses = []
for i in range(num_epochs):
    loss = model.fit(train_dataset, nb_epoch=1)
    print("Epoch %d loss: %f" % (i, loss))
    losses.append(loss)
Let's evaluate this model on the training and validation set to get some basic understanding of its accuracy. We'll use the ROC-AUC as our metric of choice.
In [5]:
import numpy as np
metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score, np.mean, mode="classification")
print("Evaluating model")
train_scores = model.evaluate(train_dataset, [metric], transformers)
valid_scores = model.evaluate(valid_dataset, [metric], transformers)
print("Train scores")
print(train_scores)
print("Validation scores")
print(valid_scores)
LIME can work on any problem with a fixed-size input vector. It works by sampling perturbations of the input around the instance being explained, querying the model on those perturbations, and fitting a simple interpretable surrogate (such as a sparse linear model) whose weights act as local feature importances. We are going to create an explainer for our data.
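To make that procedure concrete, here is a toy, from-scratch sketch of the idea. It is not the lime library: the black_box function, the bit-flip rate, and the Ridge surrogate below are all illustrative choices rather than anything LIME prescribes.
In [ ]:
import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.RandomState(0)

def black_box(X):
    # Made-up stand-in for a trained model: probability driven only by bits 3 and 7
    logits = 2.0 * X[:, 3] - 1.5 * X[:, 7]
    return 1.0 / (1.0 + np.exp(-logits))

x0 = rng.randint(0, 2, size=16)                      # the instance to explain
perturbed = x0 ^ (rng.rand(500, 16) < 0.2)           # flip ~20% of the bits at random
similarity = (perturbed == x0).mean(axis=1)          # weight samples close to x0 more heavily

# Fit a simple linear surrogate to the black-box predictions on the perturbations
surrogate = Ridge(alpha=1.0)
surrogate.fit(perturbed, black_box(perturbed), sample_weight=similarity)

# The surrogate's largest-magnitude coefficients are the local "explanation"
top = np.argsort(np.abs(surrogate.coef_))[::-1][:5]
for i in top:
    print("bit %d: weight %+.3f" % (i, surrogate.coef_[i]))
In this toy case the surrogate recovers bits 3 and 7, the only bits the toy model actually uses. The lime package automates this sampling, weighting, and fitting for us.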
However, before we can go that far, we first need to install lime. Luckily, lime is conveniently available on PyPI, so you can install it from within this Jupyter notebook.
In [6]:
!pip install lime
Now that we have lime installed, we want to create a LimeTabularExplainer object. This object takes in the training data and names for the features. We're using circular fingerprints as our features, which have no natural names, so we simply number them. On the other hand, we do have natural names for our labels: recall that Tox21 consists of toxicity assays, so let's call class 0 'not toxic' and class 1 'toxic'.
In [0]:
from lime import lime_tabular

feature_names = ["fp_%s" % x for x in range(n_features)]
# Fingerprint bits are binary, so every column is categorical;
# categorical_features expects a list of column indices
explainer = lime_tabular.LimeTabularExplainer(train_dataset.X,
                                              feature_names=feature_names,
                                              categorical_features=list(range(n_features)),
                                              class_names=['not toxic', 'toxic'],
                                              discretize_continuous=True)
We are going to attempt to explain why the model predicts a molecule to be toxic for NR-AR, the first of the twelve Tox21 tasks. The specific assay details can be found here.
In [0]:
# LIME needs a function that takes a 2d numpy array (samples, features) and
# returns class probabilities of shape (samples, n_classes)
def eval_model(my_model):
    def eval_closure(x):
        ds = dc.data.NumpyDataset(x, n_tasks=12)
        # The 0th task is NR-AR
        predictions = my_model.predict(ds)[:, 0]
        return predictions
    return eval_closure
model_fn = eval_model(model)
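Before handing model_fn to LIME, a quick sanity check is worthwhile: LIME's classification mode expects this function to return per-class probabilities. The short sketch below simply reuses model_fn on a few test-set molecules; with DeepChem's MultitaskClassifier the output here should have shape (samples, 2).
In [ ]:
# Sanity check: LIME expects class probabilities of shape (samples, n_classes)
probs = model_fn(test_dataset.X[:3])
print(probs.shape)
print(probs)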
Let's now attempt to use this evaluation function on a specific molecule. For ease, let's pick the first molecule in the test set that is toxic for NR-AR.
In [9]:
# Imaging imports to get pictures in the notebook
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
# We want to investigate a toxic compound
active_id = np.where(test_dataset.y[:,0]==1)[0][0]
print(active_id)
Chem.MolFromSmiles(test_dataset.ids[active_id])
Out[9]:
In [0]:
# explain_instance returns a LIME Explanation object
# The explanation contains details of why the model behaved the way it did
exp = explainer.explain_instance(test_dataset.X[active_id], model_fn, num_features=5, top_labels=1)
In [11]:
# In a Jupyter notebook, we can render the explanation inline
exp.show_in_notebook(show_table=True, show_all=False)
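The rendering above is interactive. If you are running outside a notebook, or just want the raw numbers, the same information can be read off the Explanation object directly; the short sketch below uses lime's as_list and save_to_file methods (since we passed top_labels=1, only the top predicted label has an explanation), and the HTML filename is just an example.
In [ ]:
# Print the (feature, weight) pairs for the label that was explained
label = exp.available_labels()[0]
for feature, weight in exp.as_list(label=label):
    print("%s: %+.3f" % (feature, weight))
# The explanation can also be saved as a standalone HTML page
exp.save_to_file("lime_explanation.html")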
This output shows the fragments that the model believes contributed towards toxicity or non-toxicity. We can reverse the hash function and look at the fragments that activated those fingerprint bits for this molecule.
In [12]:
def fp_mol(mol, fp_length=1024):
    """
    Returns a dict mapping fingerprint index to the set of fragment SMILES
    strings that activated that fingerprint bit for the given molecule.
    """
    d = {}
    feat = dc.feat.CircularFingerprint(sparse=True, smiles=True, size=fp_length)
    retval = feat._featurize(mol)
    for k, v in retval.items():
        index = k % fp_length
        if index not in d:
            d[index] = set()
        d[index].add(v['smiles'])
    return d
# What fragments activated what fingerprints in our active molecule?
my_fp = fp_mol(Chem.MolFromSmiles(test_dataset.ids[active_id]))

# We can calculate which fragments activate all fingerprint
# indexes throughout our entire training set
all_train_fps = {}
X = train_dataset.X
ids = train_dataset.ids
for i in range(len(X)):
    d = fp_mol(Chem.MolFromSmiles(ids[i]))
    for k, v in d.items():
        if k not in all_train_fps:
            all_train_fps[k] = set()
        all_train_fps[k].update(v)
In [13]:
# We can visualize the fragment in our active molecule that set fingerprint
# bit 242, one of the bits flagged by the explanation above (the exact bit
# indices may differ from run to run)
Chem.MolFromSmiles(list(my_fp[242])[0])
Out[13]:
We can also look in the training set for fragments that activate the same fingerprint bit but are missing from our molecule. According to our explanation, having one of these fragments would make a molecule more likely to be toxic.
In [14]:
Chem.MolFromSmiles(list(all_train_fps[242])[0])
Out[14]:
In [15]:
Chem.MolFromSmiles(list(all_train_fps[242])[2])
Out[15]:
In [16]:
Chem.MolFromSmiles(list(all_train_fps[242])[4])
In [17]:
Chem.MolFromSmiles(list(all_train_fps[242])[1])
In [18]:
Chem.MolFromSmiles(list(all_train_fps[242])[3])
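Rather than looking up bits one at a time, we can tie the pieces together programmatically. Here is a short sketch (reusing the exp, my_fp, and all_train_fps objects built above) that maps each fingerprint bit in the explanation back to the fragment SMILES that set it, either in our molecule or somewhere in the training set.
In [ ]:
# as_map() returns (column index, weight) pairs for the explained label
label = exp.available_labels()[0]
for bit, weight in exp.as_map()[label]:
    fragments = my_fp.get(bit, all_train_fps.get(bit, set()))
    print("bit %d (weight %+.3f): %s" % (bit, weight, sorted(fragments)))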
Using LIME on fragment-based models can give you intuition about which fragments contribute to your response variable, at least in a locally linear fashion.
Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:
Star DeepChem on GitHub: this helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.
Join the DeepChem Gitter: it hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!