PyCEBox Tutorial



In [1]:
%matplotlib inline

In [2]:
from matplotlib import pyplot as plt
from matplotlib.cm import PuOr
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats  # load the stats submodule explicitly; sp.stats is used below
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor

In [3]:
from pycebox.ice import ice, ice_plot

In [4]:
# from random.org, for reproducibility
np.random.seed(400845)


This tutorial uses pycebox to recreate the first example from "Peeking Inside the Black Box: Visualizing Statistical Learning with Plots of Individual Conditional Expectation". For details of pycebox's API, consult the documentation.

First we generate 1,000 data points from the model $X_1, X_2, X_3 \sim U(-1, 1)$, $\varepsilon \sim N(0, 1)$,

$$y = 0.2 X_1 - 5 X_2 + 10 X_2 \cdot \mathbb{I}(X_3 \geq 0) + \varepsilon,$$

where $\mathbb{I}(\cdot)$ is the indicator function.
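
Taking expectations makes the structure of this model explicit. The short derivation below is an addition to the original example. It assumes the predictors are independent of each other and of $\varepsilon$ (as they are in the simulation code) and uses only $E[X_1] = 0$, $E[\varepsilon] = 0$, and $P(X_3 \geq 0) = 1/2$. Conditional on all three predictors, the slope in $X_2$ flips sign with $X_3$:

$$E[y \mid X_1, X_2, X_3] = \begin{cases} 0.2 X_1 - 5 X_2, & X_3 < 0, \\ 0.2 X_1 + 5 X_2, & X_3 \geq 0. \end{cases}$$

Averaging over $X_1$ and $X_3$ removes the dependence on $X_2$ entirely:

$$E[y \mid X_2] = 0.2\, E[X_1] - 5 X_2 + 10 X_2 \cdot P(X_3 \geq 0) = -5 X_2 + 5 X_2 = 0.$$

So an average over data points can look flat in $X_2$ even though every individual data point depends on it strongly.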


In [5]:
N = 1000

In [6]:
df = pd.DataFrame(sp.stats.uniform.rvs(-1, 2, size=(N, 3)),
                  columns=['x1', 'x2', 'x3'])
noise = sp.stats.norm.rvs(size=N)

In [7]:
y = 0.2 * df.x1 - 5 * df.x2 + 10 * df.x2 * (df.x3 >= 0) + noise

We will study the relationship between $y$ and $X_2$, which is shown below.


In [8]:
fig, ax = plt.subplots(figsize=(8, 6))

ax.scatter(df.x2, y, c='k', alpha=0.5);

ax.set_xlim(-1.05, 1.05);
ax.set_xlabel('$X_2$');

ax.set_ylabel('$y$');

ax.set_title('Data');



In [9]:
gbm = GradientBoostingRegressor()
gbm.fit(df.values, y)


Out[9]:
GradientBoostingRegressor(alpha=0.9, init=None, learning_rate=0.1, loss='ls',
             max_depth=3, max_features=None, max_leaf_nodes=None,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, n_estimators=100,
             presort='auto', random_state=None, subsample=1.0, verbose=0,
             warm_start=False)

We can now use pycebox's ice function to generate individual conditional expectation curves with respect to the fitted model and the predictor $X_2$.


In [10]:
ice_df = ice(df, 'x2', gbm.predict, num_grid_points=100)

Each column of ice_df corresponds to a data point and is labeled by that point's observed values (data_x2, x1, and x3 in the output below). Each row corresponds to a grid value of $X_2$ at which the model is evaluated.
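
To make the layout of rows and columns concrete, here is a minimal sketch of how ICE curves can be computed by hand. It is not pycebox's implementation. The names manual_ice, true_mean, and toy are hypothetical, and the noise-free mean of the model above stands in for gbm.predict so that the values are easy to check.

```python
import numpy as np
import pandas as pd


def manual_ice(data, column, predict, grid):
    """Hypothetical helper: one ICE curve per row of `data`, evaluated on `grid`."""
    curves = {}
    for x in grid:
        modified = data.copy()
        modified[column] = x  # overwrite the predictor for every data point
        curves[x] = predict(modified.values)
    # Transpose so that rows are grid values and columns are data points.
    return pd.DataFrame(curves, index=data.index).T


def true_mean(X):
    """Noise-free mean of the simulated model (stand-in for a fitted model)."""
    return 0.2 * X[:, 0] - 5 * X[:, 1] + 10 * X[:, 1] * (X[:, 2] >= 0)


# Two toy data points: one with x3 < 0 and one with x3 >= 0.
toy = pd.DataFrame({'x1': [0.0, 0.0], 'x2': [0.3, -0.4], 'x3': [-0.5, 0.5]})
toy_ice = manual_ice(toy, 'x2', true_mean, grid=np.array([-1.0, 0.0, 1.0]))
```

In toy_ice, the curve for the first point (x3 < 0) falls as $X_2$ grows and the curve for the second point (x3 >= 0) rises, mirroring the two branches of the data generating process.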


In [11]:
ice_df.head()


Out[11]:
data_x2 -0.999685 -0.997213 -0.996914 -0.994768 -0.992745 -0.991811 -0.990806 -0.990703 -0.989268 -0.986413 ... 0.981406 0.981860 0.985768 0.986167 0.986809 0.987194 0.988980 0.991358 0.992368 0.994746
x1 0.039279 -0.072787 -0.776739 0.741882 -0.524129 0.494542 0.467775 -0.476077 0.773273 0.857048 ... 0.025781 0.213430 0.368466 -0.913941 0.238791 0.773319 -0.710572 -0.055058 -0.068680 0.101857
x3 -0.620973 0.231508 -0.217072 -0.503590 0.574635 -0.630984 -0.058337 0.845764 0.891552 0.173355 ... -0.964609 -0.050217 -0.754750 0.110510 0.281101 0.924249 0.541984 -0.339092 0.730080 0.890106
x2
-0.999685 5.757221 -3.867726 5.767918 5.484838 -3.812199 5.313498 4.695467 -3.843710 -4.624771 -3.352770 ... 6.548622 5.205355 5.651336 -4.487253 -3.801561 -4.848384 -4.346604 5.677944 -4.618255 -4.520552
-0.986017 3.470374 -5.622630 3.495009 4.252989 -5.512691 3.899456 3.447572 -5.544202 -5.735421 -4.517831 ... 4.261775 3.084656 3.954332 -6.422371 -5.556464 -5.959034 -6.112436 3.470374 -6.318748 -6.221044
-0.961970 4.419360 -4.152061 4.443995 4.612133 -4.042123 4.258600 3.806715 -4.073633 -4.854695 -3.637105 ... 5.190631 4.033642 4.293346 -4.836927 -4.085895 -5.078308 -4.641867 4.419360 -4.848179 -4.750476
-0.946674 4.358820 -3.932877 4.383455 4.551592 -3.822938 4.198060 3.746175 -3.854449 -4.635511 -3.417921 ... 5.130091 3.973101 4.232806 -4.617743 -3.866711 -4.859124 -4.422683 4.358820 -4.628995 -4.531291
-0.919890 4.358820 -3.932877 4.383455 4.551592 -3.822938 4.198060 3.746175 -3.854449 -4.635511 -3.417921 ... 5.130091 3.973101 4.232806 -4.617743 -3.866711 -4.859124 -4.422683 4.358820 -4.628995 -4.531291

5 rows × 1000 columns

The individual conditional expectation curves in ice_df can now be plotted using your visualization package of choice. pycebox includes a convenience function, ice_plot, for plotting the individual conditional expectation curves using matplotlib.
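
As an illustration of plotting the curves directly, the sketch below draws a small hand-made frame with plain matplotlib. The frame toy_curves is a hypothetical stand-in with the same orientation as ice_df (rows are $X_2$ grid values, columns are data points). It is not the output of ice, and this is not how ice_plot is implemented.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend for scripts; omit this line in a notebook
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for an ICE frame: three curves on a five-point grid.
grid = np.linspace(-1, 1, 5)
toy_curves = pd.DataFrame({'a': 5 * grid, 'b': -5 * grid, 'c': 5 * grid + 1},
                          index=grid)

fig, ax = plt.subplots(figsize=(8, 6))

# Passing a 2-D array draws one line per column, i.e. one line per data point.
lines = ax.plot(toy_curves.index, toy_curves.values, color='k', alpha=0.25)

ax.set_xlabel('$X_2$')
ax.set_ylabel('$y$')
ax.set_title('ICE Curves')
```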


In [12]:
fig, (data_ax, ice_ax) = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(16, 6))

data_ax.scatter(df.x2, y, c='k', alpha=0.5);

data_ax.set_xlim(-1.05, 1.05);
data_ax.set_xlabel('$X_2$');

data_ax.set_ylabel('$y$');

data_ax.set_title('Data');

ice_plot(ice_df, frac_to_plot=0.1,
         c='k', alpha=0.25,
         ax=ice_ax);

ice_ax.set_xlabel('$X_2$');

ice_ax.set_ylabel('$y$');

ice_ax.set_title('ICE Curves');


/opt/conda/lib/python3.5/site-packages/ipykernel/__main__.py:14: PerformanceWarning: indexing past lexsort depth may impact performance.

Inspecting the ICE curves, it seems quite likely that there is an important interaction. Since we know the data generating process, we know that the key interaction is between $X_2$ and $X_3$. The function ice_plot accepts a color_by keyword, which takes either a string naming a column of the initial DataFrame or a function that is applied to the column index of the ICE DataFrame. The result is used to color the ICE curves.
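
The grouping by $X_3$ can also be checked numerically with pandas alone. The sketch below assumes the column layout visible in Out[11], a MultiIndex with levels data_x2, x1, and x3. It builds a small hand-made stand-in, toy_ice_df, instead of using the real ice_df, and its values are made up for illustration.

```python
import pandas as pd

# Hypothetical stand-in with the same column levels as the ICE DataFrame above.
columns = pd.MultiIndex.from_tuples(
    [(-0.9, 0.1, -0.6), (0.2, -0.3, 0.4), (0.7, 0.5, -0.1)],
    names=['data_x2', 'x1', 'x3'])
toy_ice_df = pd.DataFrame(
    [[5.0, -5.0, 5.0],
     [0.0, 0.0, 0.0],
     [-5.0, 5.0, -5.0]],
    index=pd.Index([-1.0, 0.0, 1.0], name='x2'),
    columns=columns)

# Read each data point's x3 value from the column index and split on its sign.
x3 = toy_ice_df.columns.get_level_values('x3')
positive = toy_ice_df.loc[:, x3 >= 0]
negative = toy_ice_df.loc[:, x3 < 0]

# Average curve within each group: opposite slopes indicate an interaction.
positive_mean = positive.mean(axis=1)
negative_mean = negative.mean(axis=1)
```

If the two group averages have opposite slopes in $X_2$, as they do in this toy frame, coloring the curves by x3 should separate them visually as well.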


In [13]:
fig, (data_ax, ice_ax) = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(16, 6))

data_ax.scatter(df.x2, y, c='k', alpha=0.5);

data_ax.set_xlim(-1.05, 1.05);
data_ax.set_xlabel('$X_2$');

data_ax.set_ylabel('$y$');

data_ax.set_title('Data');

ice_plot(ice_df, frac_to_plot=0.1,
         color_by='x3', cmap=PuOr,
         ax=ice_ax);

ice_ax.set_xlabel('$X_2$');

ice_ax.set_ylabel('$y$');

ice_ax.set_title('ICE Curves');


This plot makes the interaction between $X_2$ and $X_3$ quite transparent.

Additionally, ice_plot accepts the plot_points keyword. When plot_points=True, the predicted value for each data point is plotted on its ICE curve.


In [14]:
fig, (data_ax, ice_ax) = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(16, 6))

data_ax.scatter(df.x2, y, c='k', alpha=0.5);

data_ax.set_xlim(-1.05, 1.05);
data_ax.set_xlabel('$X_2$');

data_ax.set_ylabel('$y$');

data_ax.set_title('Data');

ice_plot(ice_df, frac_to_plot=0.1,
         plot_points=True, point_kwargs={'color': 'k', 'alpha': 0.75},
         color_by='x3', cmap=PuOr,
         ax=ice_ax);

ice_ax.set_xlabel('$X_2$');

ice_ax.set_ylabel('$y$');

ice_ax.set_title('ICE Curves');