Tutorial Part 8: Introduction to Model Interpretability

In the previous sections of this tutorial series, you have learned how to train models with DeepChem for a variety of applications. You have also learned how to model the uncertainty associated with a model's predictions. But we have not yet really studied the question of model explainability.

Oftentimes when modeling we are asked: how does the model work? Why should we trust this model? My response as a data scientist is usually "because we have rigorously validated model performance on a held-out test set, with splits that reflect the real world." Often that is not enough to convince domain experts.

LIME is a tool that can help with this problem. It perturbs the input locally in feature space and fits a simple, interpretable surrogate model to those perturbations to estimate feature importance. In this tutorial, you'll learn how to use LIME alongside DeepChem to interpret what our models are learning.

LIME has been used to produce human-understandable explanations for image classifiers, so can it work on molecules too? In this tutorial you will learn how to use LIME for model interpretability with any of our fixed-length featurization models.
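
To make the idea concrete before we use the library, here is a minimal, purely illustrative sketch of what LIME does under the hood. The names black_box_fn and instance are placeholders, and the real library is considerably more careful about sampling, kernels, and feature selection; this is just a sketch of the local-surrogate idea, not the lime package itself.

import numpy as np
from sklearn.linear_model import Ridge

def lime_sketch(black_box_fn, instance, num_samples=1000, kernel_width=0.75):
    """Sketch of a local linear explanation for a binary-fingerprint model.

    black_box_fn: callable returning one scalar prediction per sample.
    instance: 1D numpy array, the fingerprint we want to explain.
    """
    n_features = instance.shape[0]
    # Perturb the instance by randomly zeroing out fingerprint bits
    mask = np.random.binomial(1, 0.5, size=(num_samples, n_features))
    perturbed = instance * mask
    # Query the black-box model on the perturbed samples
    preds = black_box_fn(perturbed)
    # Weight each perturbed sample by its similarity to the original instance
    distances = np.sqrt(((perturbed - instance) ** 2).sum(axis=1))
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)
    # Fit a weighted linear surrogate; its coefficients act as local feature importances
    surrogate = Ridge(alpha=1.0)
    surrogate.fit(perturbed, preds, sample_weight=weights)
    return surrogate.coef_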

Colab

This tutorial and the rest in this sequence are designed to be done in Google Colab. If you'd like to open this notebook in Colab, you can use the following link.

Setup

To run DeepChem within Colab, you'll need to run the following cell of installation commands. This will take about 5 minutes to run to completion and install your environment.


In [1]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')


TensorFlow 1.x selected.
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  3477  100  3477    0     0  34425      0 --:--:-- --:--:-- --:--:-- 34425
add /root/miniconda/lib/python3.6/site-packages to PYTHONPATH
python version: 3.6.9
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
done
installing miniconda to /root/miniconda
done
installing deepchem
done
/usr/local/lib/python3.6/dist-packages/sklearn/externals/joblib/__init__.py:15: FutureWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.
  warnings.warn(msg, category=FutureWarning)
WARNING:tensorflow:
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

deepchem-2.3.0 installation finished!
CPU times: user 3.16 s, sys: 757 ms, total: 3.92 s
Wall time: 2min 24s

Making the Model

The first thing we have to do is train a model. Here we are going to train a toxicity model using circular fingerprints. The first step will be to load up our trusty Tox21 dataset.


In [2]:
from deepchem.molnet import load_tox21

# Load Tox21 dataset
n_features = 1024
tox21_tasks, tox21_datasets, transformers = load_tox21(reload=False)
train_dataset, valid_dataset, test_dataset = tox21_datasets


Loading raw samples now.
shard_size: 8192
About to start loading CSV from /tmp/tox21.csv.gz
Loading shard 1 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
TIMING: featurizing shard 0 took 33.641 s
TIMING: dataset construction took 33.962 s
Loading dataset from disk.
TIMING: dataset construction took 0.403 s
Loading dataset from disk.
TIMING: dataset construction took 0.203 s
Loading dataset from disk.
TIMING: dataset construction took 0.204 s
Loading dataset from disk.
TIMING: dataset construction took 0.340 s
Loading dataset from disk.
TIMING: dataset construction took 0.048 s
Loading dataset from disk.
TIMING: dataset construction took 0.049 s
Loading dataset from disk.

Let's now define a model to work on this dataset. Because LIME's tabular explainer requires a fixed-length input vector, for now we can only use a fully connected network model.


In [3]:
import deepchem as dc

n_tasks = len(tox21_tasks)
n_features = train_dataset.get_data_shape()[0]
model = dc.models.MultitaskClassifier(n_tasks, n_features)


WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

Our next goal is to train this model on the Tox21 dataset. Let's train for 10 epochs so we have a reasonably converged model.


In [4]:
num_epochs = 10
losses = []
for i in range(num_epochs):
    loss = model.fit(train_dataset, nb_epoch=1)
    print("Epoch %d loss: %f" % (i, loss))
    losses.append(loss)


WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:169: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/optimizers.py:76: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:258: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:260: The name tf.variables_initializer is deprecated. Please use tf.compat.v1.variables_initializer instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:237: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/losses.py:108: The name tf.losses.softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.softmax_cross_entropy instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/losses.py:109: The name tf.losses.Reduction is deprecated. Please use tf.compat.v1.losses.Reduction instead.

Epoch 0 loss: 0.225362
Epoch 1 loss: 0.146278
Epoch 2 loss: 0.125541
Epoch 3 loss: 0.115947
Epoch 4 loss: 0.112123
Epoch 5 loss: 0.101710
Epoch 6 loss: 0.100300
Epoch 7 loss: 0.101758
Epoch 8 loss: 0.090115
Epoch 9 loss: 0.090089
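
Since we collected the per-epoch losses above, we can optionally plot them to eyeball convergence. This is just a convenience sketch using matplotlib (already available in the Colab environment); it isn't required for the rest of the tutorial.

import matplotlib.pyplot as plt

# Plot the training loss we recorded for each epoch
plt.plot(range(num_epochs), losses)
plt.xlabel("Epoch")
plt.ylabel("Training loss")
plt.title("Tox21 training loss")
plt.show()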

Let's evaluate this model on the training and validation sets to get a basic sense of its performance. We'll use ROC-AUC as our metric of choice.


In [5]:
import numpy as np

metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score, np.mean, mode="classification")

print("Evaluating model")
train_scores = model.evaluate(train_dataset, [metric], transformers)
valid_scores = model.evaluate(valid_dataset, [metric], transformers)

print("Train scores")
print(train_scores)

print("Validation scores")
print(valid_scores)


Evaluating model
computed_metrics: [0.9911460306475237, 0.9962989723827874, 0.9757023239869564, 0.986256863445856, 0.9259520300246388, 0.9873943742049194, 0.9918725451398143, 0.9379407998794907, 0.9928536256898868, 0.9772374789653557, 0.965923828259603, 0.981542764445936]
computed_metrics: [0.599564636619997, 0.8016699735449735, 0.810107859645929, 0.7260421962379258, 0.6494545454545455, 0.7463417512390842, 0.694942021460713, 0.8004415322107142, 0.7417588886272664, 0.722559331175836, 0.8338163788354211, 0.7412575366063738]
Train scores
{'mean-roc_auc_score': 0.975843469756064}
Validation scores
{'mean-roc_auc_score': 0.7389963876382316}

Using LIME

LIME can work on any problem with a fixed-size input vector. The tabular explainer computes statistics of the individual training features and uses them to generate local perturbations of an instance; a simple surrogate model fit to those perturbations then scores feature importance. We are going to create an explainer for our data.

However, before we can go that far, we first need to install lime. Luckily, lime is conveniently available on PyPI, so you can install it with pip from within this Jupyter notebook.


In [6]:
!pip install lime


Collecting lime
  Downloading https://files.pythonhosted.org/packages/27/ee/4aaac4cd79f16329746495aca96f8c35f278b5c774eff3358eaa21e1cbf3/lime-0.2.0.0.tar.gz (274kB)
     |████████████████████████████████| 276kB 2.8MB/s 
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from lime) (3.2.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from lime) (1.18.5)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from lime) (1.4.1)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from lime) (4.41.1)
Collecting pillow==5.4.1
  Downloading https://files.pythonhosted.org/packages/85/5e/e91792f198bbc5a0d7d3055ad552bc4062942d27eaf75c3e2783cf64eae5/Pillow-5.4.1-cp36-cp36m-manylinux1_x86_64.whl (2.0MB)
     |████████████████████████████████| 2.0MB 8.8MB/s 
Requirement already satisfied: scikit-learn>=0.18 in /usr/local/lib/python3.6/dist-packages (from lime) (0.22.2.post1)
Requirement already satisfied: scikit-image>=0.12 in /usr/local/lib/python3.6/dist-packages (from lime) (0.16.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->lime) (2.4.7)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->lime) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->lime) (1.2.0)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->lime) (2.8.1)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.18->lime) (0.15.1)
Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.12->lime) (2.4)
Requirement already satisfied: imageio>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.12->lime) (2.4.1)
Requirement already satisfied: PyWavelets>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.12->lime) (1.1.1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib->lime) (1.12.0)
Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx>=2.0->scikit-image>=0.12->lime) (4.4.2)
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... done
  Created wheel for lime: filename=lime-0.2.0.0-cp36-none-any.whl size=284181 sha256=784faa7c9728629fe2d9ea8f11d74cd9e8e8e3c2346e8e9ecf7dc9f16ce42f0b
  Stored in directory: /root/.cache/pip/wheels/22/f2/ec/e5ebd07348b2b1ac722e91c2f549fcc220f7d5f25497a61232
Successfully built lime
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
Installing collected packages: pillow, lime
  Found existing installation: Pillow 7.0.0
    Uninstalling Pillow-7.0.0:
      Successfully uninstalled Pillow-7.0.0
Successfully installed lime-0.2.0.0 pillow-5.4.1

Now that we have lime installed, we want to create an Explainer object. This object takes in the training dataset and names for the features. We're using circular fingerprints as our features, and since they have no natural names, we simply number them. Our labels, on the other hand, do have natural names: Tox21 consists of toxicity assays, so let's call 0 'not toxic' and 1 'toxic'.


In [0]:
from lime import lime_tabular

# Every fingerprint bit is a binary feature, so we pass the column indices
# to categorical_features (LIME expects indices here, not the feature names)
feature_names = ["fp_%s" % x for x in range(1024)]
explainer = lime_tabular.LimeTabularExplainer(train_dataset.X,
                                              feature_names=feature_names,
                                              categorical_features=list(range(1024)),
                                              class_names=['not toxic', 'toxic'],
                                              discretize_continuous=True)

We are going to attempt to explain why the model predicts a molecule to be toxic for the NR-AR task. The specific assay details can be found here.


In [0]:
# We need a function that takes a 2D numpy array (samples, features) and returns
# per-class probabilities of shape (samples, n_classes), which is what LIME
# expects for classification
def eval_model(my_model):
    def eval_closure(x):
        ds = dc.data.NumpyDataset(x, n_tasks=12)
        # The 0th task is NR-AR
        predictions = my_model.predict(ds)[:,0]
        return predictions
    return eval_closure
model_fn = eval_model(model)
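
As a quick optional sanity check, we can confirm that model_fn produces one probability per class, which is the shape LIME's classification mode expects. The expected shape below assumes that DeepChem 2.3's MultitaskClassifier returns per-task class probabilities when calling predict.

# Run the prediction function on a few test molecules and inspect the shape
probs = model_fn(test_dataset.X[:3])
print(probs.shape)  # expected: (3, 2), i.e. 'not toxic'/'toxic' probabilities for NR-AR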

Let's now attempt to use this evaluation function on a specific molecule. For ease, let's pick the first molecule in the test set that is labeled toxic for NR-AR.


In [9]:
# Imaging imports to get pictures in the notebook
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D

# We want to investigate a toxic compound
active_id = np.where(test_dataset.y[:,0]==1)[0][0]
print(active_id)
Chem.MolFromSmiles(test_dataset.ids[active_id])


41
Out[9]:

In [0]:
# explain_instance returns a LIME Explanation object,
# which contains the details of why the model behaved the way it did
exp = explainer.explain_instance(test_dataset.X[active_id], model_fn, num_features=5, top_labels=1)

In [11]:
# In an IPython notebook, the explanation can render itself inline
exp.show_in_notebook(show_table=True, show_all=False)


This output shows the fragments that the model believes contributed towards toxicity/non-toxicity. We can reverse the hashing and look at the fragments that activated those fingerprint bits for this molecule.


In [12]:
def fp_mol(mol, fp_length=1024):
    """
    returns: dict of <int: set of str>
        dictionary mapping fingerprint index
        to the set of SMILES fragments that activated that fingerprint
    """
    d = {}
    feat = dc.feat.CircularFingerprint(sparse=True, smiles=True, size=fp_length)
    retval = feat._featurize(mol)
    for k, v in retval.items():
        index = k % fp_length
        if index not in d:
            d[index] = set()
        d[index].add(v['smiles'])
    return d
# What fragments activated what fingerprints in our active molecule?
my_fp = fp_mol(Chem.MolFromSmiles(test_dataset.ids[active_id]))

# We can calculate which fragments activate all fingerprint
# indexes throughout our entire training set
all_train_fps = {}
X = train_dataset.X
ids = train_dataset.ids
for i in range(len(X)):
    d = fp_mol(Chem.MolFromSmiles(ids[i]))
    for k, v in d.items():
        if k not in all_train_fps:
            all_train_fps[k] = set()
        all_train_fps[k].update(v)


RDKit WARNING: [02:41:10] WARNING: not removing hydrogen atom without neighbors

In [13]:
# We can visualize the fragment behind one of the fingerprint bits that the
# explanation flagged as contributing to toxicity for the active molecule
# we are investigating
Chem.MolFromSmiles(list(my_fp[242])[0])


Out[13]:

We can also see which fragments our molecule is missing by investigating the training set, since other fragments map to the same fingerprint bit. According to our explanation, having one of these fragments would make our molecule more likely to be toxic.


In [14]:
Chem.MolFromSmiles(list(all_train_fps[242])[0])


Out[14]:

In [15]:
Chem.MolFromSmiles(list(all_train_fps[242])[2])


Out[15]:

In [16]:
Chem.MolFromSmiles(list(all_train_fps[242])[4])


RDKit ERROR: [02:41:18] non-ring atom 0 marked aromatic

In [17]:
Chem.MolFromSmiles(list(all_train_fps[242])[1])


RDKit ERROR: [02:41:18] non-ring atom 0 marked aromatic

In [18]:
Chem.MolFromSmiles(list(all_train_fps[242])[3])


RDKit ERROR: [02:41:18] non-ring atom 0 marked aromatic

Using LIME on fragment-based models can give you intuition about which fragments contribute to your response variable, at least locally and in a linear fashion.
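
If you want those contributions as numbers rather than as the notebook widget, the Explanation object returned by explain_instance also exposes them programmatically. This snippet assumes the exp object from the cells above is still in scope.

# Each entry pairs a feature description (a fingerprint bit) with its local weight
top_label = exp.available_labels()[0]
for feature, weight in exp.as_list(label=top_label):
    print(feature, weight)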

Congratulations! Time to join the Community!

Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

Star DeepChem on GitHub

This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

Join the DeepChem Gitter

The DeepChem Gitter hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!