This notebook is an example of investigating how different input features (a.k.a. predictors, independent variables) affect a neural network's predictions, in order to help understand and explain how a model works.
This sort of model interpretation can be useful for a variety of purposes.
Similar methods exist for other types of models, for example computing feature importances for tree ensembles or examining the coefficients of a linear model. It is worth noting that all of these methods have limitations, especially for more complex model types such as tree ensembles and neural networks.
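For context, here is a minimal, self-contained sketch (an addition, not part of the original notebook) of those two analogues on the same digits data: feature_importances_ from a scikit-learn random forest and coef_ from a logistic regression. The model choices and hyperparameters below are illustrative only.

from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

X_demo, y_demo = load_digits(return_X_y=True)

# Tree ensemble: impurity-based importance per feature (per pixel).
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_demo, y_demo)
print(rf.feature_importances_.reshape(8, 8).round(3))

# Linear model: per-class, per-feature coefficients play an analogous role.
lr = LogisticRegression(max_iter=1000).fit(X_demo, y_demo)
print(lr.coef_.shape)  # (n_classes, n_features) = (10, 64)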
We'll use the toy "digits" dataset available in scikit-learn. Obviously, an MLP is not a great choice for a computer vision task, but the simple visual patterns in the digits data make for a decent demonstration.
We'll use muffnn.MLPClassifier.prediction_gradient, which can be used for the "gradient x input" method discussed by, e.g., Shrikumar et al. (2016). This method computes the gradient of the output (the prediction) with respect to the input vector (e.g., Simonyan et al., ICLR 2014) and then multiplies by the input vector to discount less-active features. Prediction gradient values are essentially coefficients for a local linear model around each example, and in that way the approach is related to LIME, SHAP, and other methods.
Note: Prediction gradients are sometimes referred to as sensitivity analysis.
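To make the "local linear model" interpretation concrete, the underlying first-order approximation (standard calculus, not anything specific to muffnn) is: for an input $x$ and class $c$, the model output $f_c$ near $x$ satisfies

$$f_c(x') \approx f_c(x) + \nabla_x f_c(x)^\top (x' - x),$$

so the gradient entries act as local coefficients, and the gradient x input attribution for feature $j$ is $\frac{\partial f_c(x)}{\partial x_j}\, x_j$, roughly the effect of moving feature $j$ from zero to its observed value under that linear approximation.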
In [1]:
import base64
import io
import logging
from IPython.display import HTML, display
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import muffnn
from sklearn.datasets import load_digits
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import GridSearchCV
import tensorflow as tf
# Ignore an irrelevant tensorflow deprecation warning triggered below.
logging.getLogger('tensorflow').setLevel(logging.CRITICAL)
In [2]:
def image_html(x, cmap, vmin=None, vmax=None):
    """Render the 64-pixel array x as an 8x8 image and return an inline HTML <img> tag."""
    plt.figure(figsize=(1, 1))
    img = plt.imshow(np.reshape(x, (8, 8)), interpolation='nearest', vmin=vmin, vmax=vmax)
    img.set_cmap(cmap)
    plt.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf)
    buf.seek(0)
    image_bytes = buf.read()
    plt.close()
    return ('<img style="display:inline" src="data:image/png;base64,%s" />' %
            base64.encodebytes(image_bytes).decode('ascii'))


def show_images(X, cmap, vmin=None, vmax=None):
    """Display the rows of X side by side as images."""
    display(HTML(''.join([image_html(x, cmap, vmin=vmin, vmax=vmax) for x in X])))
In [3]:
# Load the data
data = load_digits()
# Normalize X to 0-1 range.
X, y = data.data / 16, data.target
# Define a monochromatic colormap (white to black) for images with values in [0, 1].
mono_cmap = mpl.colors.LinearSegmentedColormap.from_list('foo', [(1, 1, 1), (0, 0, 0)])
# Define a red-to-gray-to-blue colormap for values in [-1, 1]: negative maps to red, positive to blue.
dual_cmap = mpl.colors.LinearSegmentedColormap.from_list('foo', [(1.0, 0.2, 0.2), (0.5, 0.5, 0.5), (0.2, 0.2, 1.0)])
# Show a sample.
sample_ind = np.random.RandomState(42).choice(np.arange(X.shape[0]), 30, replace=False)
show_images(X[sample_ind,:], mono_cmap)
In [4]:
# Train and cross-validate a very, very simple model.
params = {
    'keep_prob': [1.0],
    'hidden_units': [[256]]
}
gs = GridSearchCV(muffnn.MLPClassifier(batch_size=16, n_epochs=50, random_state=42),
                  params)
gs.fit(X, y)
mlp = gs.best_estimator_
print(mlp)
print("accuracy:", gs.best_score_)
Now we can compute prediction gradients for the data. We'll just use in-sample data for simplicity. In practice, one might want to do this with held-out data, depending on the task.
The output will be a 3-dimensional array of shape (n_examples, n_classes, n_inputs): for each example (dimension 0), the gradient of each class's output (dimension 1) with respect to each input feature (dimension 2).
In [5]:
pg_vals = mlp.prediction_gradient(X)
pg_vals.shape
Out[5]:
(1797, 10, 64)
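As an optional sanity check (an addition here, not part of the original analysis), a few of these gradient entries can be compared against a central finite-difference approximation built from predict_proba. This assumes that prediction_gradient differentiates the same quantities that predict_proba returns; if it differentiates pre-softmax logits instead, the two values will differ.

# Hedged sketch: numerically check one gradient entry via central differences.
# The example, class, and pixel indices below are chosen arbitrarily.
eps = 1e-4
i, c, j = 0, int(y[0]), 30
x_plus, x_minus = X[i].copy(), X[i].copy()
x_plus[j] += eps
x_minus[j] -= eps
numeric = (mlp.predict_proba(x_plus[np.newaxis, :])[0, c] -
           mlp.predict_proba(x_minus[np.newaxis, :])[0, c]) / (2 * eps)
print("numeric: %.6f  analytic: %.6f" % (numeric, pg_vals[i, c, j]))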
Let's try to get an idea of which inputs (pixels) the model is associating with which output classes (digits 0, 1, 2, ... 9).
We'll compare aggregating the raw gradients with aggregating gradient x input. For each, we'll plot 10 images, one per class (digit) in increasing order from left to right, with blue indicating that a pixel has a positive aggregate gradient for the given class, red indicating a negative one, and gray indicating that the pixel has little effect on the prediction.
First, we'll simply sum the gradients over all examples, which leaves one pixel map per output class. We can see some patterns related to the typical shapes of digits (e.g., the circle shape for "0"). However, note how many of the pixels around the edges, which are mostly inactive (near zero) in the data, are still fairly red or blue.
In [6]:
# Recall from above that axis/dimension 0 is for example number, so we're summing over examples here.
vmax_abs = np.abs(pg_vals.sum(axis=0)).max() # adjust color scale
show_images(pg_vals.sum(axis=0), dual_cmap, vmin=-vmax_abs, vmax=vmax_abs)
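To substantiate the observation about edge pixels, here is a small optional check (an addition, reusing X and pg_vals from above): compare the aggregated gradient magnitude at pixels that are almost never active in the data against the remaining pixels. Because the raw gradient does not depend on whether a pixel is actually active, rarely-active pixels can still carry sizable gradients, which is what motivates multiplying by the input in the next step.

# Pixels that are essentially always off across the dataset (threshold chosen arbitrarily).
rarely_active = X.mean(axis=0) < 0.01
# Mean absolute aggregated gradient per pixel, averaged over the 10 classes.
agg_grad = np.abs(pg_vals.sum(axis=0)).mean(axis=0)
print("rarely-active pixels:", int(rarely_active.sum()))
print("mean |aggregated gradient| at rarely-active pixels:", agg_grad[rarely_active].mean())
print("mean |aggregated gradient| at other pixels:", agg_grad[~rarely_active].mean())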
Multiplying by the input (the "gradient x input" method), as below, will focus the outputs more on the effects of frequently active features.
Note how red pixels appear in areas that, when absent, distinguish the digit from other digits. For example, the image for "6" (4th from the right) has red pixels in the upper right; if one took a "6" and added pixels in that area, it would start to look like an "8". Similarly, the image for "0" has red in the center, indicating that the inactivity of those pixels distinguishes "0" from other digits (e.g., "8").
In [7]:
# Multiply the prediction gradients ndarray by the input features,
# using np.newaxis to add an axis for classes to broadcast over.
# pg_vals is of shape (n_examples, n_classes, n_inputs).
# X is of shape (n_examples, n_inputs).
# Then, take the mean across examples.
gi_vals_mean = (pg_vals * X[:, np.newaxis, :]).mean(axis=0)
# Define a sensible color scale, then plot.
vabs_max = np.max(np.abs(gi_vals_mean))
show_images(gi_vals_mean, dual_cmap, vmin=-vabs_max, vmax=vabs_max)
In general, distinguishing positive from negative effects seems useful, but one can also take the absolute value.
In [8]:
show_images(np.abs(gi_vals_mean),
            mono_cmap, vmin=0, vmax=vabs_max)
One can also aggregate the results across classes (below, by summing the absolute values), though this doesn't seem particularly useful for the digits case.
In [9]:
# Compute the sum of absolute values across all classes.
# Then, add an axis with np.newaxis, since show_images expects an array of shape (n_examples, n_inputs).
show_images(np.abs(gi_vals_mean).sum(axis=0)[np.newaxis, :],
            mono_cmap)
In [10]:
# Sample some data and compute predictions, probabilities, and gradients for plotting.
x_sample = X[sample_ind, :]
yhat_sample = gs.predict(x_sample)
yhat_proba_sample = gs.predict_proba(x_sample)
pg_sample = mlp.prediction_gradient(x_sample)
# color scale for plots
vmax_abs = np.abs(pg_sample * x_sample[:, np.newaxis, :]).max()
for i, (x, yhat, yprob, pg) in enumerate(zip(x_sample, yhat_sample, yhat_proba_sample, pg_sample)):
    print("=" * 30)
    print("yhat=%d %s (example %d)" %
          (yhat, '[' + ' '.join(['%.3f' % v for v in yprob]) + ']', i))
    # Note: np.newaxis is used just because show_images expects a (n_examples, n_inputs) array,
    # but we're only plotting one image at a time.
    show_images(np.reshape(x, (8, 8))[np.newaxis, :], mono_cmap)
    show_images(pg[yhat][np.newaxis, :] * x, dual_cmap, vmin=-vmax_abs, vmax=vmax_abs)
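As a possible follow-up (an addition here, reusing the variables from the cell above), the most influential pixels can also be listed numerically rather than visually, e.g. the largest-magnitude gradient x input values for the predicted class of the first sampled example. The choice of the first example and of the top 5 pixels is arbitrary.

# Gradient x input attributions for the predicted class of the first sampled example.
attributions = pg_sample[0, yhat_sample[0], :] * x_sample[0]
# Indices of the 5 largest-magnitude attributions.
top = np.argsort(-np.abs(attributions))[:5]
for j in top:
    print("pixel (row %d, col %d): attribution %+.4f" % (j // 8, j % 8, attributions[j]))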