shap.dependence_plot
This notebook is designed to demonstrate (and so document) how to use the shap.dependence_plot
function. It uses an XGBoost model trained on the classic UCI adult income dataset (a classification task to predict whether people made over $50k in the 1990s).
In [2]:
import xgboost
import shap
# train XGBoost model
X,y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)
# compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
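The SHAP values come back in log-odds units, with one value per sample and feature. As a quick sanity check (assuming a single array is returned, which is typical for a binary XGBoost model), the SHAP matrix should have the same shape as the data matrix:
In [ ]:
# one SHAP value per sample and per feature, so the shapes should match
print(shap_values.shape, X.shape)
assert shap_values.shape == X.shape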
A dependence plot is a scatter plot that shows the effect a single feature has on the predictions made by the model. In this example, the log-odds of making over $50k increase significantly between ages 20 and 40.
In [3]:
# The first argument is the index of the feature we want to plot
# The second argument is the matrix of SHAP values (it is the same shape as the data matrix)
# The third argument is the data matrix (a pandas dataframe or numpy array)
shap.dependence_plot(0, shap_values, X)
In [4]:
# If we pass a numpy array instead of a data frame then we
# need to pass the feature names in separately
shap.dependence_plot(0, shap_values, X.values, feature_names=X.columns)
In [5]:
# We can pass a feature name instead of an index
shap.dependence_plot("Age", shap_values, X)
In [6]:
# We can also use the special "rank(i)" syntax to specify the i'th most
# important feature to the model, as measured by np.abs(shap_values).mean(0).
# In this example age is the second most important feature.
shap.dependence_plot("rank(1)", shap_values, X)
In [7]:
# The interaction_index argument can be used to explicitly
# set which feature gets used for coloring
shap.dependence_plot("rank(1)", shap_values, X, interaction_index="Education-Num")
In [8]:
# we can turn off interaction coloring
shap.dependence_plot("Age", shap_values, X, interaction_index=None)
In [9]:
# we can use shap.approximate_interactions to guess which features
# may interact with age
inds = shap.approximate_interactions("Age", shap_values, X)
# make plots colored by each of the top three possible interacting features
for i in range(3):
    shap.dependence_plot("Age", shap_values, X, interaction_index=inds[i])
In [10]:
import matplotlib.pyplot as plt
# you can use the cmap parameter to provide your own custom color map
shap.dependence_plot("Age", shap_values, X, cmap=plt.get_cmap("cool"))
In [12]:
# by passing show=False you can prevent shap.dependence_plot from calling
# the matplotlib show() function, and so you can keep customizing the plot
# before eventually calling show yourself
shap.dependence_plot(0, shap_values, X, show=False)
plt.title("Age dependence plot")
plt.ylabel("SHAP value for the 'Age' feature")
# plt.savefig("my_dependence_plot.pdf") # we can save a PDF of the figure if we want
plt.show()
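Because show=False leaves the current matplotlib figure open, any standard matplotlib customization works the same way. For example (a sketch using only standard matplotlib calls; the file name is just an example), you can resize the figure and save a high-resolution PNG before showing it:
In [ ]:
shap.dependence_plot(0, shap_values, X, show=False)
fig = plt.gcf()              # grab the figure SHAP drew into
fig.set_size_inches(8, 5)    # resize before saving
fig.savefig("age_dependence.png", dpi=200, bbox_inches="tight")
plt.show()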
In [13]:
# you can use xmax and xmin with a percentile notation to hide outliers
shap.dependence_plot(0, shap_values, X, xmin="percentile(1)", xmax="percentile(99)")
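xmin and xmax also accept plain numbers, so you can clip the x-axis to an explicit range instead of a percentile (a sketch; the bounds here are illustrative):
In [ ]:
# clip the Age axis to an explicit numeric range
shap.dependence_plot(0, shap_values, X, xmin=20, xmax=70)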
In [14]:
# transparency can help reveal dense vs. sparse areas of the scatter plot
shap.dependence_plot(0, shap_values, X, alpha=0.1)
In [15]:
# an alternative to transparency is to reduce the dot size
shap.dependence_plot(0, shap_values, X, dot_size=2)
In [16]:
# for categorical (or binned) data adding a small amount of x-jitter makes
# thin columns of dots more readable
shap.dependence_plot(0, shap_values, X, x_jitter=1, dot_size=1)
In [109]:
X_cat = X.copy()
relationship_decoding = {
0: 'Not-in-family',
1: 'Unmarried',
2: 'Other-relative',
3: 'Own-child',
4: 'Husband',
5: 'Wife'
}
X_cat["Relationship"] = X_cat["Relationship"].map(relationship_decoding)
X_cat.head(3)
In [108]:
# You can use string-valued categorical features
shap.dependence_plot("Relationship", shap_values, X_cat)
In [107]:
# It is also possible to use a string-valued feature to color the interaction effect
shap.dependence_plot(0, shap_values, X_cat, interaction_index="Relationship")