Graphical representations of linear models

This notebook provides examples of how five functions in the seaborn plotting library (regplot, corrplot, lmplot, interactplot, and coefplot) can be used to visualize the relationships between variables in a dataset. The functions are intended to produce attractive plots that can be specified without much work. The goal of these visualizations is to emphasize important comparisons in the dataset and provide supporting information without distraction.

These functions are a bit higher-level than the ones covered in the distributions and timeseries tutorials. Instead of plotting into an existing axis, most expect to have the whole figure to themselves, and they are frequently composed of multiple axes.

All of the functions covered here are pandas-aware, and both lmplot and coefplot require that the data be in a pandas DataFrame.


In [1]:
import numpy as np
from numpy import mean
from numpy.random import randn
import statsmodels.formula.api as sm
import pandas as pd
from scipy import stats

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

In [2]:
sns.set(palette="Purples_r")
np.random.seed(9221999)
mpl.rc("figure", figsize=(5, 5))

Plotting a simple regression: regplot

First, we will show how to visualize the relationship between two variables. To do so, we'll create a very simple fake data set.


In [3]:
x1 = randn(80)
x2 = randn(80)
x3 = x1 * x2
y1 = .5 + 2 * x1 - x2 + 2.5 * x3 + 3 * randn(80)
y2 = .5 + 2 * x1 - x2 + 2.5 * randn(80)
y3 = y2 + randn(80)
y_logistic = 1 / (1 + np.exp(-y1))
y_flip = [np.random.binomial(1, p) for p in y_logistic]
df = pd.DataFrame(dict(x1=x1, x2=x2, x3=x3, y1=y1, y2=y2, y3=y3, y_flip=y_flip))
d = randn(100, 30)

The regplot function accepts two arrays of data and draws a scatterplot of the relationship. It also fits a regression line to the data, bootstraps the regression to get a confidence interval (95% by default), and plots the marginal distributions of the two variables. It does all this with a very simple function call:


In [4]:
sns.regplot(x2, y2)
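
To make the bootstrapped confidence interval less of a black box, here is a minimal sketch of the general technique using only numpy. It is not regplot's actual implementation; the helper name bootstrap_regression_band and the simulated data are made up for illustration.

```python
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(80)
y = 0.5 - x + 2.5 * rng.randn(80)

def bootstrap_regression_band(x, y, grid, n_boot=1000, ci=95, rng=rng):
    """Resample (x, y) pairs with replacement, refit a line each time,
    and return percentile bounds of the fitted line at each grid point."""
    n = len(x)
    fits = np.empty((n_boot, len(grid)))
    for i in range(n_boot):
        idx = rng.randint(0, n, n)  # row indices drawn with replacement
        slope, intercept = np.polyfit(x[idx], y[idx], 1)
        fits[i] = intercept + slope * grid
    tail = (100 - ci) / 2.0
    lower, upper = np.percentile(fits, [tail, 100 - tail], axis=0)
    return lower, upper

grid = np.linspace(x.min(), x.max(), 50)
slope, intercept = np.polyfit(x, y, 1)
line = intercept + slope * grid          # the fitted regression line
lower, upper = bootstrap_regression_band(x, y, grid)
```

The band between lower and upper is what would be shaded around the line. Requesting a smaller interval (for example ci=68) simply takes percentiles closer to the middle of the same bootstrap distribution.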


The function also works with pandas DataFrame objects, in which case the x and y values should be strings. Calling the function this way gets axis labels for free (otherwise you can use the xlabel and ylabel keyword arguments).

It's also easy to manipulate the appearance of the plot at a high level by setting a global color and altering the confidence interval for the regression estimate. By default, it plots the 95% confidence interval (this is computed with a bootstrap), but you may prefer to use a 68% CI, which corresponds to the standard error of the regression estimate.


In [5]:
sns.regplot("x1", "y2", df, ci=68, color="slategray")


You'll note that a Pearson correlation statistic is automatically computed and displayed in the scatterplot. If your data are not normally distributed, you can provide a different function to calculate a correlation metric; anything that takes two arrays of data and returns either a single numeric statistic or a (stat, p) tuple will work.

For large datasets, the bootstrap may become computationally intensive (and confidence intervals with large datasets will dramatically shrink), so you may want to turn it off.


In [6]:
sns.regplot("x3", "y1", df, corr_func=stats.spearmanr, ci=None, color="steelblue")
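
If you want to write your own statistic, the only requirement described above is the calling convention: two arrays in, a number (or a (stat, p) tuple) out. As a sketch, here is a hypothetical rank_corr function that computes a Spearman-style rank correlation with plain numpy. It ignores ties, so it is an illustration of the interface rather than a replacement for stats.spearmanr.

```python
import numpy as np

def rank_corr(x, y):
    """Spearman-style rank correlation for data without ties.
    Returns a single number, one of the two accepted return shapes."""
    rx = np.argsort(np.argsort(x))  # 0-based rank of each observation
    ry = np.argsort(np.argsort(y))
    return np.corrcoef(rx, ry)[0, 1]

x = np.array([0.1, 0.5, 1.2, 2.0, 3.5])
y = np.exp(x)  # monotonic in x, but not linear

rank_stat = rank_corr(x, y)             # essentially 1: the ranks agree perfectly
pearson_stat = np.corrcoef(x, y)[0, 1]  # below 1: the relationship is curved
```

Based on the description above, a function like this should be usable as the corr_func argument, though that has not been checked here.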


The default behavior for the fit statistic is to use the function name, but sometimes you might want to use a different string. That's what the func_name keyword argument is for.


In [7]:
r2 = lambda x, y: stats.pearsonr(x, y)[0] ** 2
sns.regplot("x1", "y2", df, corr_func=r2, func_name="$R^2$", color="seagreen")


For finer control over the individual aspects of the plot, you can pass dictionaries with keyword arguments for the underlying seaborn or matplotlib functions.


In [8]:
sns.regplot("x3", "y3", df,
            reg_kws={"lw": 2},
            scatter_kws={"mfc": "none", "ms": 4, "mew": 1.5},
            text_kws={"family": "serif", "size": 11},
            dist_kws={"hist": False, "kde_kws":{"shade": True}})


Pairwise correlations in a large dataset: corrplot

Now let's explore a more complex dataset. We'll use the tips data that is provided with R's reshape2 package. This is a good example dataset in that it provides several quantitative and qualitative variables in a tidy format, but there aren't actually any interesting interactions, so I am open to suggestions for other datasets to use here.


In [9]:
tips = pd.read_csv("https://raw.github.com/mwaskom/seaborn/master/examples/tips.csv")
tips["big_tip"] = tips.tip > (.2 * tips.total_bill)
tips["smoker"] = tips["smoker"] == "Yes"
tips["female"] = tips["sex"] == "Female"
mpl.rc("figure", figsize=(7, 7))

Plotting correlation heatmaps

Once you have a tidy dataset, often the first thing you want is a very high-level summary of the relationships between the variables. Correlation matrix heatmaps can be very useful for this purpose. The corrplot function not only plots a color-coded correlation matrix, but it will also obtain a p value for each correlation using a permutation test to give you some indication of the significance of each relationship while correcting for multiple comparisons in an intelligent way.


In [10]:
sns.corrplot(tips);
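
For intuition about what a permutation test does, here is a minimal numpy sketch for a single pair of variables. The function name permutation_pvalue and the simulated arrays are assumptions for illustration, and the multiple-comparisons correction that corrplot applies is deliberately left out because its details are not described here.

```python
import numpy as np

def permutation_pvalue(x, y, n_perm=2000, seed=0):
    """Two-sided permutation p value for the Pearson correlation of x and y."""
    rng = np.random.RandomState(seed)
    observed = np.corrcoef(x, y)[0, 1]
    null = np.empty(n_perm)
    for i in range(n_perm):
        # Shuffling y breaks any real association with x
        null[i] = np.corrcoef(x, rng.permutation(y))[0, 1]
    # Fraction of shuffled correlations at least as extreme as the observed one
    return observed, np.mean(np.abs(null) >= np.abs(observed))

rng = np.random.RandomState(1)
a = rng.randn(100)
b = a + rng.randn(100)  # related to a
c = rng.randn(100)      # unrelated to a

r_ab, p_ab = permutation_pvalue(a, b)
r_ac, p_ac = permutation_pvalue(a, c)
```

Because every pair of columns needs its own set of shuffles, the cost grows quickly with the number of variables and observations, which is why this step can be slow on big datasets.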


Note that if you have a huge dataset, the permutation test will take a while. Of course, if you have a huge dataset, p values will not be particularly relevant, so you can turn off the significance testing.

It's also possible to choose a different colormap, but choose wisely! Don't even try using the "jet" map; you'll get a ValueError. The colorbar itself is also optional.


In [11]:
sns.corrplot(tips, sig_stars=False, cmap="RdBu_r", cbar=False);


By default, the colormap is centered on 0 and uses a diverging map, which is appropriate since 0 is a meaningful boundary and both large positive and negative values are interesting.

Sometimes, though, you are only interested in either positive or negative values. In these cases, you can set the tail for the significance test, which will also change the default colormap to a sequential map.


In [12]:
sns.corrplot(df, sig_tail="upper");


It's also possible to specify the range for the colormap. Note that setting the test direction modifies the colormap and range, but the inverse is not true; the stars here will still correspond to a two-tailed test.


In [13]:
sns.corrplot(df, cmap_range=(-.3, 0));


You might also have many variables, in which case the correlation coefficient annotation may not fit well. In this case, it can be turned off, and the variable names can be moved to the sides of the plot:


In [14]:
f, ax = plt.subplots(1, 1, figsize=(10, 10))
cmap = sns.blend_palette(["#00008B", "#6A5ACD", "#F0F8FF",
                          "#FFE6F8", "#C71585", "#8B0000"], as_cmap=True)
sns.corrplot(d, annot=False, diag_names=False, cmap=cmap, ax=ax);


Visualizing multiple regression: lmplot

The lmplot function provides a more general interface for plotting linear relationships in a complex set of data. In its most basic usage, it does the same thing as the core of the regplot function. Note that lmplot only works with DataFrames.


In [15]:
mpl.rc("figure", figsize=(5, 5))

In [16]:
sns.lmplot("total_bill", "tip", tips)


The advantage of using lmplot over regplot is that you can visualize linear relationships among subsets of a larger data structure. There are a few ways to do this, but perhaps the one most amenable to direct comparisons involves separating subgroups by color.


In [17]:
sns.lmplot("total_bill", "tip", tips, color="time")


The default color palette is husl, but you can use any of the seaborn color palettes for the color factor.


In [18]:
sns.lmplot("total_bill", "tip", tips, color="day", palette="muted", ci=None)


It's not actually necessary to fit a regression line to the data if you don't want to. (Although I need to fix things so that the legend shows up when using color grouping -- this doesn't work at the moment.)


In [19]:
sns.lmplot("total_bill", "tip", tips, fit_reg=False)


You can also fit higher-order polynomials. Although there is not such a trend in this dataset, let's invent one to see what that might look like.


In [20]:
tips["tip_sqr"] = tips.tip ** 2
sns.lmplot("total_bill", "tip_sqr", tips, order=2)
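
As a point of reference for what a second-order fit estimates, here is a small sketch with numpy's polyfit on made-up, noise-free data. Whether lmplot uses polyfit internally is not stated here, so treat this only as an illustration of polynomial regression itself.

```python
import numpy as np

x = np.linspace(0, 10, 50)
y = 0.3 * x ** 2 - x + 2  # exact quadratic, no noise

# Degree-2 least-squares fit; coefficients come back highest power first
coefs = np.polyfit(x, y, 2)

# Evaluate the fitted curve at new x values
fitted = np.polyval(coefs, np.array([0.0, 5.0, 10.0]))
```

With noise-free data the fit recovers the generating coefficients (0.3, -1, 2) up to floating point error; with real data you would get estimates of them instead.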


Logistic Regression

What if we want to fit a model where the response variable is categorical? (At the moment, it must be binary and numeric, so {0, 1} and {True, False} both work).


In [21]:
sns.lmplot("x3", "y_flip", df, ci=68)


This plot suggests there is a relationship between the continuous value of x3 and the likelihood that y_flip will be positive, but it has a few problems. The first is that the individual observations are all plotted on top of each other, so their distribution is obscured. We can address this issue by adding a bit of jitter to the scatter plot.


In [22]:
sns.lmplot("x3", "y_flip", df, ci=68, y_jitter=.05)
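
Jitter only changes where the points are drawn, not the data that the model sees. A minimal sketch of the idea, assuming the jitter is uniform noise within plus or minus the requested amount (the exact noise distribution lmplot uses is an assumption here):

```python
import numpy as np

rng = np.random.RandomState(0)
y = rng.binomial(1, 0.5, 100)  # binary response: only 0s and 1s

jitter = 0.05
# Displayed positions: each point is nudged by a small random amount
y_shown = y + rng.uniform(-jitter, jitter, len(y))

# The values used for fitting stay untouched
y_for_fit = y
```

After jittering, the points spread into two thin bands around 0 and 1, so overlapping observations become visible.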


A more fundamental problem follows from using basic linear regression with a binary response variable. The regression line implies that the probability of a positive y_flip with very large values of x3 is larger than 1. Of course, that doesn't make sense, which is why logistic regression was invented. lmplot can likewise plot a logistic curve over the data. You might want to use fewer bootstrap iterations, as the logistic regression fit is much more computationally intensive.


In [23]:
sns.lmplot("x3", "y_flip", df, y_jitter=.05, logistic=True, ci=68, n_boot=1000)
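
To see the difference between the two models numerically, here is a sketch that fits both to simulated binary data with numpy alone. The logistic fit uses a few steps of Newton's method, which is one standard way to estimate such a model and not necessarily what lmplot does internally. All variable names and the simulated data are illustrative.

```python
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(200)
p_true = 1 / (1 + np.exp(-(0.5 + 2.5 * x)))
y = rng.binomial(1, p_true)

# Ordinary least squares line: its predictions are unbounded
slope, intercept = np.polyfit(x, y, 1)

# Logistic regression fit by Newton's method
X = np.column_stack([np.ones_like(x), x])
beta = np.zeros(2)
for _ in range(25):
    p = 1 / (1 + np.exp(-X.dot(beta)))
    W = p * (1 - p)                      # per-observation weights
    gradient = X.T.dot(y - p)
    hessian = (X * W[:, None]).T.dot(X)
    beta = beta + np.linalg.solve(hessian, gradient)

# Compare predictions far out in the tail of the predictor
x_far = 5.0
linear_pred = intercept + slope * x_far
logistic_pred = 1 / (1 + np.exp(-(beta[0] + beta[1] * x_far)))
```

The linear prediction at x_far lands well above 1, while the logistic prediction approaches 1 without crossing it. Each bootstrap iteration has to repeat an iterative fit like the loop above, which is where the extra computation comes from.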


Faceted plots

There are several other ways to visualize fits of the model to sub-groups in the data.

You can also separate out factors into facet plots on the columns or rows.


In [24]:
sns.lmplot("total_bill", "tip", tips, col="sex")


Faceting doesn't mean you have to give up the association between colors and factors:


In [25]:
sns.lmplot("total_bill", "tip", tips, color="sex", col="sex")


By default, the same x and y axes are used for all facets, but you can turn this off if you have a big difference in intercepts that you don't care about.


In [26]:
sns.lmplot("total_bill", "tip", tips, col="sex", sharey=False)


Plotting with discrete predictor variables

Sometimes you will want to plot data where the independent variable is discrete. Although this works fine out of the box:


In [27]:
sns.lmplot("size", "tip", tips)


And can be improved with a bit of jitter:


In [28]:
sns.lmplot("size", "tip", tips, x_jitter=.15)


It might be more informative to estimate the central tendency of each bin. This is easy to do with the x_estimator argument. Just pass any function that aggregates a vector of data into one estimate. The estimator will be bootstrapped and a confidence interval will be plotted -- 95% by default, as in other cases within these functions.


In [29]:
sns.lmplot("size", "tip", tips, x_estimator=mean)
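
The idea behind estimating within each level can be sketched in a few lines of numpy. The helper estimate_by_level and the simulated size and tip arrays are made up for illustration; this shows the general bootstrap recipe described above, not seaborn's code.

```python
import numpy as np

rng = np.random.RandomState(0)
size = rng.randint(1, 5, 200)  # discrete predictor with levels 1..4
tip = 1.0 + 0.8 * size + rng.randn(200)

def estimate_by_level(x, y, estimator=np.mean, n_boot=1000, ci=95):
    """For each unique x, return (level, estimate, lower, upper)."""
    tail = (100 - ci) / 2.0
    rows = []
    for level in np.unique(x):
        vals = y[x == level]
        # Resample this level's observations with replacement
        boots = [estimator(rng.choice(vals, len(vals))) for _ in range(n_boot)]
        lower, upper = np.percentile(boots, [tail, 100 - tail])
        rows.append((level, estimator(vals), lower, upper))
    return rows

rows = estimate_by_level(size, tip)
```

Any function that reduces an array to one number could stand in for np.mean here, for example np.median.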


Sometimes you may want to plot binary factors and not extrapolate with the fitted line beyond your data points. (Here the fitted line doesn't make all that much sense for interpolating within the range of the data either, but it does make the trend more visually obvious.) Note that at the moment the independent variable must be "quantitative" (so, numerical or boolean typed), but in the future binary factors with string variables will be implemented.


In [30]:
sns.lmplot("smoker", "size", tips, ci=None, x_estimator=mean, x_ci=68, truncate=True)


With large datasets, the scatterplot can become overwhelming and fail to provide useful information about reliable linear trends. In these cases, it can be useful to bin the predictor variable into discrete values and plot an estimated central tendency and confidence interval.

Note that the regression estimate is still fit to the original data; the binning only applies to the visual representation of the observations.


In [31]:
bins = [10, 20, 30, 40]
sns.lmplot("total_bill", "tip", tips, x_bins=bins)
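
The separation between what is drawn and what is fit can be sketched as follows. The rule used here, assigning each observation to the nearest bin center, is an assumption about how the binning could work, and the data are simulated.

```python
import numpy as np

rng = np.random.RandomState(0)
total_bill = rng.uniform(5, 45, 300)
tip = 1.0 + 0.1 * total_bill + rng.randn(300)

bins = np.array([10, 20, 30, 40])

# Assign each observation to the nearest bin center (assumed binning rule)
nearest = np.abs(total_bill[:, None] - bins[None, :]).argmin(axis=1)
binned_x = bins[nearest]

# Points that would be drawn: one central-tendency estimate per bin
bin_means = np.array([tip[binned_x == b].mean() for b in bins])

# The regression itself still uses the original, unbinned predictor
slope, intercept = np.polyfit(total_bill, tip, 1)
```

Only binned_x and bin_means would feed the scatter layer; the slope and intercept come from all 300 original observations.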


Other faceting options

You can plot data on both the rows and columns to compare multiple factors at once.


In [32]:
sns.lmplot("total_bill", "tip", tips, row="sex", col="day", size=4)


And, of course, you can compose the color grouping with facets as well to facilitate comparisons within a complicated model structure.


In [33]:
sns.lmplot("total_bill", "tip", tips, col="day", color="sex", size=4)


If you have many levels for some factor (say, your population of subjects), you may want to "wrap" the levels so that the plot is not too wide:


In [34]:
sns.lmplot("total_bill", "tip", tips, ci=None, col="day", col_wrap=2, color="day", palette="Set1", size=4)


Removing nuisance variables

Because of how we generated our fake dataset, x1 and y3 appear to be related.


In [35]:
sns.lmplot("x1", "y3", df)


However, this relationship is being driven entirely by y2. We can thus ask what happens when we residualize our dependent variable against the confound before plotting.


In [36]:
sns.lmplot("x1", "y3", df, x_partial="y2")
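
Residualizing just means regressing one variable on the confound and keeping what is left over. Here is a minimal numpy sketch that follows the prose description above (removing y2 from the dependent variable y3). The helper name residualize is hypothetical, the data are re-simulated with a larger sample so the effect is easy to see, and this is not a description of how lmplot's partial arguments are implemented.

```python
import numpy as np

rng = np.random.RandomState(0)
n = 2000
x1 = rng.randn(n)
x2 = rng.randn(n)
y2 = 0.5 + 2 * x1 - x2 + 2.5 * rng.randn(n)
y3 = y2 + rng.randn(n)  # y3 depends on x1 only through y2

def residualize(target, confound):
    """Return what is left of `target` after regressing it on `confound`."""
    slope, intercept = np.polyfit(confound, target, 1)
    return target - (intercept + slope * confound)

r_before = np.corrcoef(x1, y3)[0, 1]
r_after = np.corrcoef(x1, residualize(y3, y2))[0, 1]
```

Before residualizing, x1 and y3 are clearly correlated; afterwards the correlation is close to zero, because the only path from x1 to y3 ran through y2.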


Plotting interactions between continuous variables: interactplot

Faceting and grouping the regression by color can make it easy to unpack interactions between the predictor variables in your dataset. However, these methods really only work when, at most, one of your variables is continuous. Two-way interactions between continuous variables are often interesting, but they are difficult to visualize.

The fake data we created at the start has an interaction of this sort (y1 depends on the product of x1 and x2), so let's explore how we might visualize it.


In [37]:
mpl.rc("figure", figsize=(7, 5.5))

One approach is to bin one of the predictors and then plot the data as before, pretending the predictor is categorical.


In [38]:
bins = np.linspace(-3.5, 3.5, 8)
binned = bins[np.digitize(x2, bins)] - .5
binned[binned < -1] = -1
binned[binned > 1] = 1
df["x2_binned"] = binned
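
If the digitize step looks opaque, this small example traces it on a handful of hand-picked values (the array x below is made up; the bins are the same as in the cell above). It uses np.clip for the final step, which has the same effect as the two boolean assignments.

```python
import numpy as np

bins = np.linspace(-3.5, 3.5, 8)  # edges at -3.5, -2.5, ..., 3.5
x = np.array([-2.2, -0.7, -0.2, 0.3, 1.7])

# np.digitize gives, for each value, the index of the first edge above it,
# so subtracting half a bin width moves from that edge to the bin center
centers = bins[np.digitize(x, bins)] - .5
# centers is [-2., -1., 0., 0., 2.]

# Collapse everything outside [-1, 1] into the outermost groups
clipped = np.clip(centers, -1, 1)
# clipped is [-1., -1., 0., 0., 1.]
```

The result is three groups (-1, 0, and 1), which is why the plot in the next cell has three columns.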

In [39]:
pal = sns.dark_palette("crimson", 3)
sns.lmplot("x1", "y1", df, col="x2_binned", color="x2_binned", palette=pal, ci=None, size=3.5)


This is serviceable, but lacking in several ways. It requires several cumbersome steps, the choice of the bin size is arbitrary, and collapsing the continuous data into categories loses information.

An alternative approach plots the two independent variables on the x and y axes of a plot and color-encodes the model predictions with a contour plot. This maintains the continuous nature of the data. The seaborn function interactplot draws such a plot, with an interface similar to regplot:


In [40]:
ax = sns.interactplot(df.x1, df.x2, df.y1);
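
The quantity being color-encoded is the prediction of a regression model with an interaction term. As a sketch of that computation (using numpy's least squares on re-simulated data, which may differ from how interactplot fits its model), you can build the design matrix by hand and evaluate the fitted surface on a grid:

```python
import numpy as np

rng = np.random.RandomState(0)
n = 500
x1 = rng.randn(n)
x2 = rng.randn(n)
y = 0.5 + 2 * x1 - x2 + 2.5 * x1 * x2 + rng.randn(n)

# Design matrix: intercept, both main effects, and their product
X = np.column_stack([np.ones(n), x1, x2, x1 * x2])
beta = np.linalg.lstsq(X, y, rcond=None)[0]

# Evaluate the fitted surface on a regular grid, ready for a contour plot
g1, g2 = np.meshgrid(np.linspace(-3, 3, 40), np.linspace(-3, 3, 40))
surface = beta[0] + beta[1] * g1 + beta[2] * g2 + beta[3] * g1 * g2
# e.g. plt.contourf(g1, g2, surface) would draw the filled contours
```

When the interaction coefficient (the last entry of beta) is large, the contours bend into a saddle shape; when it is near zero, they are close to parallel straight lines.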


Naturally, you can directly pass a dataframe, and also adjust the aesthetics of the plot.


In [41]:
sns.interactplot("x1", "x2", "y1", df, cmap="coolwarm", filled=True, levels=25);


The two underlying plot functions are contourf() and plot(), both of which can be tweaked with a keyword argument dictionary.

Note the appearance when we plot data that was not simulated with an interaction.


In [42]:
sns.interactplot("x1", "x2", "y2", df, filled=True,
                 scatter_kws={"color": "dimgray"},
                 contour_kws={"alpha": .5});


This works for logistic regression models, as well.


In [43]:
pal = sns.blend_palette(["#4169E1", "#DFAAEF", "#E16941"], as_cmap=True)
levels = np.linspace(0, 1, 11)
sns.interactplot("x1", "x2", "y_flip", df, levels=levels, cmap=pal, logistic=True);


Plotting linear model parameters: coefplot

Although the above plots can be very helpful for understanding the structure of your data, they fail with more than about 4 variables or with more than one continuous predictor. To visually summarize this kind of model, it can be helpful to plot the point estimates for each coefficient along with confidence intervals. The coefplot function achieves this by using a Patsy formula specification for the model structure.


In [44]:
mpl.rc("figure", figsize=(8, 5))

In [45]:
sns.coefplot("tip ~ day + time * size", tips)
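
For reference, the numbers behind a plot like this are the coefficient estimates and their uncertainty. Here is a minimal sketch of computing them by ordinary least squares with numpy, using simulated data and a normal-approximation interval. How coefplot actually computes its intervals is not described here, so this is an assumption-laden illustration only.

```python
import numpy as np

rng = np.random.RandomState(0)
n = 500
X = np.column_stack([np.ones(n), rng.randn(n), rng.randn(n)])
true_beta = np.array([1.0, 0.5, -2.0])
y = X.dot(true_beta) + rng.randn(n)

# Point estimates
beta = np.linalg.lstsq(X, y, rcond=None)[0]

# Standard errors from the residual variance and (X'X)^-1
resid = y - X.dot(beta)
sigma2 = resid.dot(resid) / (n - X.shape[1])
se = np.sqrt(np.diag(sigma2 * np.linalg.inv(X.T.dot(X))))

# Normal-approximation 95% interval for each coefficient
lower, upper = beta - 1.96 * se, beta + 1.96 * se
```

Each coefficient then becomes one point (beta) with an error bar (lower to upper), and an interval that excludes zero suggests the corresponding term matters in the model.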



In [46]:
sns.coefplot("total_bill ~ day + time + smoker", tips, ci=68, palette="muted")


When you have repeated measures in your dataset (e.g. an experiment performed with multiple subjects), you can group by the levels of that variable and plot the model coefficients within each group. Note that the semantics of the resulting figure change a little bit from the example above.


In [47]:
sns.coefplot("tip ~ time * sex", tips, "size", intercept=True)