.. _linear_quantitative:

.. currentmodule:: seaborn

Linear models with quantitative data
====================================

Linear models are very common in statistical analysis. They are used to understand how linear combinations of *predictor* (or *independent*) variables relate to a *response* (or *dependent*) variable. Seaborn has several functions for exploratory visualizations that correspond with linear regression. This page will focus on the functions that can be used when the main predictor variable (or variables) are quantitative. They differ from functions focused on :ref:`categorical variables ` in that they fit and plot a representation of the model itself in the form of a regression line.

In [ ]:
%matplotlib inline

In [ ]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

In [ ]:
np.random.seed(sum(map(ord, "linear_quantitative")))

.. _lmplot:

Visualizing multiple regression with :func:`lmplot`
---------------------------------------------------

The :func:`lmplot` function is intended for exploring linear relationships of different forms in multidimensional datasets. Input data must be in a Pandas ``DataFrame``. To plot, provide the predictor and response variable names along with the dataset:

In [ ]:
tips = sns.load_dataset("tips")

In [ ]:
sns.lmplot("total_bill", "tip", tips);
This plot has two main components. The first is a scatterplot, showing the observed datapoints. The second is a regression line, showing the estimated linear model relating the two variables. Because the regression line is only an *estimate*, it is plotted with a 95% confidence band to give an impression of the certainty in the model. You can plot different levels of certainty. For instance, it is common to plot the *standard error* of an estimate, which corresponds to the 68% confidence interval.

In [ ]:
sns.lmplot("total_bill", "tip", tips, ci=68);
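
If you're curious what goes into a band like this, here is a minimal standard-library sketch of a percentile bootstrap for the slope of a simple regression. The helper names (``fit_line``, ``bootstrap_slope_ci``) and the simulated bills and tips are assumptions made for illustration; this is not seaborn's own implementation.

```python
import random

def fit_line(xs, ys):
    """Ordinary least squares fit of y = intercept + slope * x."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return my - slope * mx, slope

def bootstrap_slope_ci(xs, ys, ci=95, n_boot=1000, seed=0):
    """Percentile bootstrap interval for the slope, resampling observations."""
    rng = random.Random(seed)
    n = len(xs)
    slopes = []
    for _ in range(n_boot):
        # Draw n observations with replacement and refit the line
        idx = [rng.randrange(n) for _ in range(n)]
        bx = [xs[i] for i in idx]
        by = [ys[i] for i in idx]
        slopes.append(fit_line(bx, by)[1])
    slopes.sort()
    lower = slopes[int(n_boot * (50 - ci / 2) / 100)]
    upper = slopes[int(n_boot * (50 + ci / 2) / 100) - 1]
    return lower, upper

# Simulated data (hypothetical): the tip is roughly 15% of the bill plus noise
rng = random.Random(1)
bills = [rng.uniform(5, 50) for _ in range(100)]
tips_toy = [0.15 * bill + rng.gauss(0, 1) for bill in bills]

ci95 = bootstrap_slope_ci(bills, tips_toy, ci=95)
ci68 = bootstrap_slope_ci(bills, tips_toy, ci=68)
print("95% interval for the slope:", ci95)
print("68% interval for the slope:", ci68)
```

Because both calls reuse the same resamples, the 68% interval always sits inside the 95% one, which is the same narrowing you see in the plot when you lower ``ci``.
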
Hopefully, the default aesthetics are suitable for exploratory graphics. However, you can also control the aesthetics of the underlying plots separately through dictionaries of keyword arguments for the matplotlib ``scatter`` and ``plot`` functions.

In [ ]:
sns.lmplot("total_bill", "tip", tips,
           scatter_kws={"marker": ".", "color": "slategray"},
           line_kws={"linewidth": 1, "color": "seagreen"});

Plotting with discrete predictor variables
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Sometimes you will want to plot data where the independent variable is quantitative, but discrete. Although this works fine out of the box:

In [ ]:
sns.lmplot("size", "tip", tips);
And can be improved with a bit of jitter:

In [ ]:
sns.lmplot("size", "tip", tips, x_jitter=.15);
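
The idea behind jitter is simple enough to sketch by hand. The snippet below assumes uniform noise and a hypothetical ``jitter`` helper; it only illustrates that the displayed positions move while the underlying values stay untouched.

```python
import random

def jitter(values, amount, seed=0):
    """Return a copy of values with uniform noise in [-amount, amount] added."""
    rng = random.Random(seed)
    return [v + rng.uniform(-amount, amount) for v in values]

party_sizes = [1, 2, 2, 3, 4, 4, 4, 6]
shown = jitter(party_sizes, 0.15)   # positions used for drawing only
print(shown)
```

Since the noise goes into a copy, anything computed from ``party_sizes`` itself is unaffected.
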
It might be more informative to estimate the central tendency of each bin. This is easy to do with the ``x_estimator`` argument. Just pass any function that aggregates a vector of data into one estimate. The estimator will be bootstrapped and a confidence interval will be plotted -- 95% by default, as in other cases within these functions.

In [ ]:
sns.lmplot("size", "tip", tips, x_estimator=np.mean);
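
To make the aggregation step concrete, here is a rough sketch of what "estimate plus bootstrapped interval per level" can look like. The function ``estimate_by_level`` and the tiny dataset are hypothetical, and the percentile bootstrap is an assumed choice rather than a description of seaborn's internals.

```python
import random
from collections import defaultdict
from statistics import mean

def estimate_by_level(xs, ys, estimator=mean, ci=95, n_boot=1000, seed=0):
    """Apply estimator to the y values at each distinct x, with a bootstrap CI."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for x, y in zip(xs, ys):
        groups[x].append(y)
    summary = {}
    for level in sorted(groups):
        vals = groups[level]
        # Resample the group with replacement and re-apply the estimator
        boots = sorted(estimator(rng.choices(vals, k=len(vals)))
                       for _ in range(n_boot))
        lower = boots[int(n_boot * (50 - ci / 2) / 100)]
        upper = boots[int(n_boot * (50 + ci / 2) / 100) - 1]
        summary[level] = (estimator(vals), lower, upper)
    return summary

sizes = [1, 1, 2, 2, 2, 3, 3, 4, 4, 4]
tip_vals = [1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 4.5, 5.0]
summary = estimate_by_level(sizes, tip_vals)
for level, (estimate, lower, upper) in summary.items():
    print(level, estimate, lower, upper)
```

Any function that maps a list to a single number can stand in for ``mean`` here, for example ``statistics.median``.
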
It can also be useful to bin a continuous predictor variable into discrete values and plot an estimated central tendency and confidence interval. This can be helpful when you have many datapoints and a reliable, but weak, effect. Note that the regression estimate will still be fit to the original data; the binning only applies to the visual representation of the observations. With :func:`lmplot`, you can provide specific centroid values for the bins:

In [ ]:
bins = [10, 20, 30, 40]
sns.lmplot("total_bill", "tip", tips, x_bins=bins);
Or, you can give a number of bins and it will find centroids so that they have equal numbers of datapoints in them:

In [ ]:
sns.lmplot("total_bill", "tip", tips, x_bins=8);
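
One simple way to get bins with equal numbers of datapoints is to sort by the predictor and cut the sorted data into equally sized chunks. The sketch below does exactly that; ``equal_count_bins`` is a hypothetical helper, and this is not necessarily how :func:`lmplot` chooses its centroids.

```python
from statistics import mean

def equal_count_bins(xs, ys, n_bins):
    """Split points sorted by x into n_bins chunks of (nearly) equal size and
    return (mean x, mean y, count) for each chunk."""
    pairs = sorted(zip(xs, ys))
    n = len(pairs)
    edges = [round(i * n / n_bins) for i in range(n_bins + 1)]
    summary = []
    for start, stop in zip(edges[:-1], edges[1:]):
        chunk = pairs[start:stop]
        summary.append((mean(x for x, _ in chunk),
                        mean(y for _, y in chunk),
                        len(chunk)))
    return summary

xs = list(range(1, 17))      # 16 evenly spaced predictor values
ys = [2 * x for x in xs]     # response is exactly 2 * x
bins_summary = equal_count_bins(xs, ys, 4)
for center, y_mean, count in bins_summary:
    print(center, y_mean, count)
```

Each row gives a bin center, the mean response in that bin, and how many points fell into it.
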

Faceted linear model plots
^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`lmplot` function is built on top of a :class:`FacetGrid`. That means it's easy to visualize how this relationship changes in different subsets of your dataset. You can read the extended :ref:`documentation ` for more details on how the :class:`FacetGrid` class works. The important thing is that you can supply the names of categorical variables that define subsets of the data to plot in different hues or along the rows and columns of a grid of axes. Using a ``hue`` facet makes it easiest to directly compare the two subsets:

In [ ]:
sns.lmplot("total_bill", "tip", tips, hue="smoker");
To make plots that reproduce better in black and white (e.g. when printed), you may want to let the scatterplot marker vary along with the hue variable.

In [ ]:
sns.lmplot("total_bill", "tip", tips, hue="smoker", markers=["x", "o"]);
Plotting in different columns of a grid also makes a plot that's easy to understand, although direct comparisons between the subsets are more difficult as the data are separated in space. This might be better when you want the viewer to focus on the relationship within each subset independently.

In [ ]:
sns.lmplot("total_bill", "tip", tips, col="smoker");
You can also assign the same variable to multiple roles. This lets you plot the data separately and use color to semantically tag different subsets of the data. Using color in this way can be helpful when you are making several different kinds of plots and want to help connect them visually.

In [ ]:
sns.lmplot("total_bill", "tip", tips, col="smoker", hue="smoker");
:func:`lmplot` accepts all arguments that you would use to initialize a :class:`FacetGrid`, and it returns the grid object after plotting for further tweaking:

In [ ]:
g = sns.lmplot("total_bill", "tip", tips, hue="day", palette="Set2",
               hue_order=["Thur", "Fri", "Sat", "Sun"])
g.set_axis_labels("Total bill (US Dollars)", "Tip");
g.set(xticks=[10, 30, 50], ylim=(0, 10), yticks=[0, 2.5, 5, 7.5, 10]);

Plotting different linear relationships
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, :func:`lmplot` shows the ordinary least squares fit to the data. However, you don't actually need to fit a regression at all:

In [ ]:
sns.lmplot("total_bill", "tip", tips, hue="time", palette="Set1", fit_reg=False);
There are several other choices for how to fit the regression. These options are mutually exclusive.

Plotting nonlinear trends
~~~~~~~~~~~~~~~~~~~~~~~~~

You might, for example, wonder if the relationship follows a higher-order trend:

In [ ]:
sns.lmplot("size", "total_bill", tips, order=2);
For an even more flexible fit, you can plot the *lowess* line, which is a nonparametric approach to regression. Note that, currently, lowess plots don't plot a confidence interval around the line.

In [ ]:
sns.lmplot("total_bill", "tip", tips, lowess=True, line_kws={"color": ".2"});
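
For a feel of what a lowess line does, here is a stripped-down sketch of locally weighted linear regression with tricube weights. It omits the robustness iterations of the full lowess procedure, and the function name, the ``frac`` default, and the toy data are assumptions for illustration only.

```python
def lowess(xs, ys, frac=2 / 3):
    """Locally weighted linear regression (no robustness iterations)."""
    n = len(xs)
    k = max(2, int(round(frac * n)))      # number of neighbours per local fit
    fitted = []
    for x0 in xs:
        dists = sorted(abs(x - x0) for x in xs)
        h = dists[k - 1] or 1e-12         # bandwidth: distance to k-th nearest point
        # Tricube weights: close points count most, points beyond h count zero
        ws = [(1 - min(abs(x - x0) / h, 1) ** 3) ** 3 for x in xs]
        sw = sum(ws)
        mx = sum(w * x for w, x in zip(ws, xs)) / sw
        my = sum(w * y for w, y in zip(ws, ys)) / sw
        sxx = sum(w * (x - mx) ** 2 for w, x in zip(ws, xs))
        sxy = sum(w * (x - mx) * (y - my) for w, x, y in zip(ws, xs, ys))
        slope = sxy / sxx if sxx else 0.0
        fitted.append(my + slope * (x0 - mx))
    return fitted

xs = [float(x) for x in range(10)]
line_ys = [3 * x + 1 for x in xs]     # exactly linear data
curve_ys = [x ** 2 for x in xs]       # curved data

smooth_line = lowess(xs, line_ys)
smooth_curve = lowess(xs, curve_ys)
print(smooth_curve)
```

On exactly linear data the smoother reproduces the line, while on the curved data it bends with the trend instead of forcing a single slope.
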

Plotting logistic regression
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Let's define a dependent measure that takes the values ``True`` and ``False``.

In [ ]:
tips["big_tip"] = (tips["tip"] / tips["total_bill"]) > .15
It's possible to plot (and fit) these data as normal:

In [ ]:
sns.lmplot("total_bill", "big_tip", tips);
This plot suggests that diners with a larger total bill are less likely to leave a big tip, but it has a few issues. The first is that the individual observations are all plotted on top of each other, so their distribution is obscured. We can address this issue by adding a bit of jitter to the scatter plot.

In [ ]:
sns.lmplot("total_bill", "big_tip", tips, y_jitter=.05);
A more fundamental problem follows from using basic linear regression with a binary response variable. You might want to read the y axis as the *probability* of a big tip, but this regression line implies a negative probability when the bill is greater than $50. Because that doesn't make sense, :func:`lmplot` can fit a *logistic regression* and plot the resulting curve over the data:

In [ ]:
sns.lmplot("total_bill", "big_tip", tips, y_jitter=.05, logistic=True);
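
The difference between the two fits is easy to check numerically. The sketch below fits both a straight line and a logistic curve to a small made-up dataset (not the tips data) and compares their predictions for a very large bill. ``fit_line`` and ``fit_logistic`` are hypothetical helpers, and Newton's method is just one assumed way to fit the logistic model.

```python
import math

def fit_line(xs, ys):
    """Ordinary least squares fit of y = intercept + slope * x."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    slope = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
             / sum((x - mx) ** 2 for x in xs))
    return my - slope * mx, slope

def fit_logistic(xs, ys, n_iter=25):
    """Fit P(y=1) = 1 / (1 + exp(-(b0 + b1 * x))) by Newton's method."""
    b0 = b1 = 0.0
    for _ in range(n_iter):
        # Gradient (g) and negative Hessian (h) of the log-likelihood
        g0 = g1 = h00 = h01 = h11 = 0.0
        for x, y in zip(xs, ys):
            p = 1 / (1 + math.exp(-(b0 + b1 * x)))
            w = p * (1 - p)
            g0 += y - p
            g1 += (y - p) * x
            h00 += w
            h01 += w * x
            h11 += w * x * x
        det = h00 * h11 - h01 * h01
        b0 += (h11 * g0 - h01 * g1) / det
        b1 += (h00 * g1 - h01 * g0) / det
    return b0, b1

# Made-up data: four diners at each bill size, 1 = left a big tip
bills = [10] * 4 + [20] * 4 + [30] * 4 + [40] * 4 + [50] * 4
big_tip = [1, 1, 1, 0] * 2 + [1, 1, 0, 0] + [1, 0, 0, 0] * 2

a0, a1 = fit_line(bills, big_tip)
b0, b1 = fit_logistic(bills, big_tip)

def linear_prob(bill):
    return a0 + a1 * bill

def logistic_prob(bill):
    return 1 / (1 + math.exp(-(b0 + b1 * bill)))

print("linear prediction at $100:  ", linear_prob(100))
print("logistic prediction at $100:", logistic_prob(100))
```

The straight line drops below zero for a large enough bill, whereas the logistic curve flattens out and stays between 0 and 1.
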

Plotting data with outliers
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes, data will have outliers. Because linear regression is based on the mean, it can be overly influenced by outliers, especially in cases with relatively few datapoints. :func:`lmplot` can also plot the regression model using robust estimation, which reduces the influence of outlying points. Robust estimation is computationally intensive, so you may want to reduce the number of bootstrap iterations when using it.

In [ ]:
sns.lmplot("total_bill", "tip", tips, robust=True, n_boot=500);
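
To see why down-weighting outliers helps, here is a small sketch of one robust technique: iteratively reweighted least squares with Huber weights. This is an assumed stand-in for illustration; the text above does not say which robust estimator is used, and ``weighted_line`` and ``huber_line`` are hypothetical helpers.

```python
from statistics import median

def weighted_line(xs, ys, ws):
    """Weighted least squares fit of y = intercept + slope * x."""
    sw = sum(ws)
    mx = sum(w * x for w, x in zip(ws, xs)) / sw
    my = sum(w * y for w, y in zip(ws, ys)) / sw
    sxx = sum(w * (x - mx) ** 2 for w, x in zip(ws, xs))
    sxy = sum(w * (x - mx) * (y - my) for w, x, y in zip(ws, xs, ys))
    slope = sxy / sxx
    return my - slope * mx, slope

def huber_line(xs, ys, k=1.345, n_iter=50):
    """Robust line fit: points with large residuals get smaller weights."""
    ws = [1.0] * len(xs)
    for _ in range(n_iter):
        b0, b1 = weighted_line(xs, ys, ws)
        resid = [y - (b0 + b1 * x) for x, y in zip(xs, ys)]
        # Robust scale estimate from the median absolute residual
        scale = median(abs(r) for r in resid) / 0.6745 or 1e-12
        ws = [1.0 if abs(r) <= k * scale else k * scale / abs(r)
              for r in resid]
    return weighted_line(xs, ys, ws)

xs = list(range(10))
ys = [2 * x + 1 + (0.5 if x % 2 else -0.5) for x in xs]   # true slope is 2
ys[-1] = 60.0                                             # one wild outlier

_, ols_slope = weighted_line(xs, ys, [1.0] * len(xs))
_, robust_slope = huber_line(xs, ys)
print("ordinary slope:", ols_slope)
print("robust slope:  ", robust_slope)
```

The single outlier drags the ordinary slope far from 2, while the reweighted fit stays close to it. The repeated refitting is also a hint at why robust estimation costs more.
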

Controlling for confounding variables
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A major feature of multiple regression is the ability to "control for" confounding variables. Seaborn brings this ability into the visualization realm, allowing you to regress variables out of the predictor and/or response variables before plotting. Let's make a simple three-variable dataset where variables A and B appear related, but are actually independent once you account for variable C.

In [ ]:
c = np.random.randn(50)
a = 2 + c + np.random.randn(50)
b = 3 + c + np.random.randn(50)
df = pd.DataFrame(dict(A=a, B=b, C=c), columns=list("ABC"))
If you plot the relationship between A and B, they will look related:

In [ ]:
sns.lmplot("A", "B", df);
However, if you remove (or "partial out") variable C from variable A before plotting, you can see that they are in fact independent.

In [ ]:
sns.lmplot("A", "B", df, x_partial="C");
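
"Partialling out" amounts to replacing a variable with its residuals after regressing it on the confound. Here is a minimal sketch of that step; ``residualize`` and ``pearson`` are hypothetical helpers, and the sample is larger than the one above so that the correlations are stable.

```python
import random

def pearson(u, v):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(u)
    mu, mv = sum(u) / n, sum(v) / n
    suv = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    suu = sum((a - mu) ** 2 for a in u)
    svv = sum((b - mv) ** 2 for b in v)
    return suv / (suu * svv) ** 0.5

def residualize(target, confound):
    """Residuals of target after a simple linear regression on confound."""
    n = len(target)
    mt, mc = sum(target) / n, sum(confound) / n
    slope = (sum((c - mc) * (t - mt) for c, t in zip(confound, target))
             / sum((c - mc) ** 2 for c in confound))
    return [t - (mt + slope * (c - mc)) for t, c in zip(target, confound)]

rng = random.Random(0)
n = 1000
c = [rng.gauss(0, 1) for _ in range(n)]
a = [2 + ci + rng.gauss(0, 1) for ci in c]
b = [3 + ci + rng.gauss(0, 1) for ci in c]

r_raw = pearson(a, b)
a_partial = residualize(a, c)       # A with the linear effect of C removed
r_partial = pearson(a_partial, b)
print("corr(A, B):               ", r_raw)
print("corr(A with C removed, B):", r_partial)
```

The raw correlation is clearly positive, but it shrinks to roughly zero once the shared influence of C has been removed from A.
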

.. _regplot:

Plotting simple regression with :func:`regplot`
-----------------------------------------------

For most exploratory uses, :func:`lmplot` should easily produce an informative graphic. In some cases, though, you may want to draw a scatterplot and regression model in the context of a figure with different subplots. Because the :class:`FacetGrid` initializes its own figure and controls the whole subplot grid, it can't be used here. In fact, :func:`lmplot` is just using a lower-level function, :func:`regplot`, to draw the data and the regression line. :func:`regplot` is lower-level in the sense that it can plot onto an existing Axes and doesn't modify anything about the figure it will be drawn onto. In situations where you want more control, it might be advantageous to use :func:`regplot` directly.

In [ ]:
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
sns.regplot("total_bill", "tip", tips, ax=ax1)
sns.boxplot(tips["tip"], tips["size"], color="Blues_r", ax=ax2).set_ylabel("")
Because :func:`lmplot` is using :func:`regplot` behind the scenes, all of the options for binning the datapoints, fitting the regression, and modifying the aesthetics are shared. What's missing from :func:`regplot` are the faceting options, so each call can only draw a single scatterplot and regression line. The other difference is that :func:`regplot` also accepts data passed directly as numpy arrays or pandas series objects, if you don't have your data organized into a dataframe. You can also directly control the color of the plot:

In [ ]:
x, y = np.random.multivariate_normal([1, 5], [(2, -.8), (-.8, 2)], 80).T
ax = sns.regplot(x, y, color="seagreen")
ax.set(xlabel="x variable", ylabel="y variable");

.. _residplot:

Examining model residuals using :func:`residplot`
-------------------------------------------------

The validity of linear models is based on a set of assumptions, several of which are about the distribution of the model residuals. Looking at the residuals of your model can also clue you in to interesting higher-order trends that you might otherwise miss. It's easy to visualize the residuals of a simple model in seaborn using the :func:`residplot` function. It looks a lot like :func:`regplot`, but instead of plotting the observations and regression model, it plots the residuals:

In [ ]:
sns.residplot(x, y);
Ideally, the residuals should be scattered about the zero line, without any apparent structure. Let's see what happens when we use data that actually has a higher-order relationship:

In [ ]:
y = x + 1.5 * x ** 2 + np.random.randn(len(x))
sns.residplot(x, y, color="indianred");
Although the structure here is pretty obvious, you can help visualize it by fitting a lowess smoother to the residuals:

In [ ]:
sns.residplot(x, y, color="indianred", lowess=True);
:func:`residplot` lets you calculate the residuals with a polynomial regression. Let's see if fitting a second-order regression removes this structure:

In [ ]:
sns.residplot(x, y, color="indianred", order=2, lowess=True);
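
The same comparison can be worked through by hand on noise-free data. This sketch fits first- and second-order polynomials by solving the normal equations and prints both sets of residuals; ``polyfit`` and ``residuals`` are hypothetical helpers written for this example, not seaborn functions.

```python
def polyfit(xs, ys, order):
    """Least-squares polynomial coefficients (lowest power first), found by
    solving the normal equations with Gauss-Jordan elimination."""
    m = order + 1
    # Augmented matrix [X'X | X'y] for the basis 1, x, x**2, ...
    a = [[sum(x ** (i + j) for x in xs) for j in range(m)]
         + [sum(y * x ** i for x, y in zip(xs, ys))] for i in range(m)]
    for col in range(m):
        pivot = max(range(col, m), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        for r in range(m):
            if r != col:
                factor = a[r][col] / a[col][col]
                a[r] = [v - factor * p for v, p in zip(a[r], a[col])]
    return [a[i][m] / a[i][i] for i in range(m)]

def residuals(xs, ys, order=1):
    """Observed minus fitted values for a polynomial fit of the given order."""
    coefs = polyfit(xs, ys, order)
    return [y - sum(c * x ** i for i, c in enumerate(coefs))
            for x, y in zip(xs, ys)]

xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [x + 1.5 * x ** 2 for x in xs]     # quadratic relationship, no noise

resid_linear = residuals(xs, ys, order=1)
resid_quadratic = residuals(xs, ys, order=2)
print(resid_linear)
print(resid_quadratic)
```

The first-order residuals trace out a U shape (positive at the ends, negative in the middle), which is the structure a lowess line picks up. The second-order residuals are zero up to rounding error.
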

.. _jointplot_reg:

Plotting marginal distributions using :func:`jointplot`
-------------------------------------------------------

When you are interested in a single relationship, you may want to use :func:`jointplot` to draw a regression plot along with the marginal distributions of the two variables. By default, :func:`jointplot` just draws a scatterplot:

In [ ]:
sns.jointplot("total_bill", "tip", tips);
However, if you set ``kind="reg"``, it will fit and bootstrap a regression:

In [ ]:
sns.jointplot("total_bill", "tip", tips, kind="reg", color="seagreen");
There is also a mode for plotting model residuals with :func:`residplot`. This also fits a Gaussian distribution to the marginal distribution on ``y`` to help you judge whether the residuals are normally distributed around 0:

In [ ]:
sns.jointplot("total_bill", "tip", tips, kind="resid", color="#774499");
The :func:`jointplot` function is using a :class:`JointGrid` to draw this plot. See the :ref:`docs ` to see more information about the options to this function and how you can use :class:`JointGrid` directly for more flexibility.

.. _interactplot:

Describing continuous interactions with :func:`interactplot`
------------------------------------------------------------

Hopefully, :func:`lmplot` is useful for a variety of visualization problems. However, one limitation is that it can plot, at most, a single continuous predictor variable. Although two-way interactions between continuous variables are often interesting, they can be difficult to visualize and understand. Let's make some fake data with a two-way interaction to show how you might approach this problem.

In [ ]:
rs = np.random.RandomState(1)
x1, x2 = rs.normal(size=(2, 100))
y1 = .5 * x1 + 2 * x2 + 2 * x1 * x2 + rs.normal(size=100)
y2 = x1 + x2 + rs.normal(size=100)
p = 1 / (1 + np.exp(-y1))
y_flip = [rs.binomial(1, p_i) for p_i in p]
df = pd.DataFrame(dict(x1=x1, x2=x2, y1=y1, y2=y2, y_flip=y_flip))
One approach is to bin one of the predictors and then plot the data as before, treating the predictor as categorical.

In [ ]:
bins = np.array([-1, 0, 1])
binned = np.digitize(x2, bins)
df["x2_bin"] = binned
sns.lmplot("x1", "y1", df, col="x2_bin", hue="x2_bin", palette="PuBuGn_d", size=3);
This is serviceable, but lacking in several ways. It requires several cumbersome steps, the choice of the bin size is arbitrary, and collapsing the continuous data into categories loses information. An alternative approach plots the two independent variables on the x and y axes of a plot and color-encodes the model predictions with a contour plot. This maintains the continuous nature of the data. The seaborn function :func:`interactplot` draws such a plot using an interface similar to :func:`regplot`:

In [ ]:
sns.interactplot(x1, x2, y1);
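
The model predictions behind such a plot come from a regression with a product term. As a rough sketch, assuming a model of the form ``y = b0 + b1*x1 + b2*x2 + b3*x1*x2``, the code below recovers the coefficients from noise-free data on a small grid and shows how the effect of ``x1`` depends on ``x2``. All helper names are hypothetical.

```python
def solve_least_squares(rows, ys):
    """Least-squares coefficients for a design matrix given as a list of rows,
    found by Gauss-Jordan elimination on the normal equations."""
    m = len(rows[0])
    a = [[sum(row[i] * row[j] for row in rows) for j in range(m)]
         + [sum(row[i] * y for row, y in zip(rows, ys))] for i in range(m)]
    for col in range(m):
        pivot = max(range(col, m), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        for r in range(m):
            if r != col:
                factor = a[r][col] / a[col][col]
                a[r] = [v - factor * p for v, p in zip(a[r], a[col])]
    return [a[i][m] / a[i][i] for i in range(m)]

def design(x1, x2):
    """Design row: intercept, both main effects, and their product."""
    return [1.0, x1, x2, x1 * x2]

grid = [-1.0, 0.0, 1.0]
points = [(u, v) for u in grid for v in grid]
ys = [0.5 * u + 2 * v + 2 * u * v for u, v in points]   # same form as y1 above

coefs = solve_least_squares([design(u, v) for u, v in points], ys)

def predict(x1, x2):
    return sum(c * d for c, d in zip(coefs, design(x1, x2)))

# Effect of a one-unit increase in x1 at two different values of x2
slope_when_x2_high = predict(1, 1) - predict(0, 1)
slope_when_x2_low = predict(1, -1) - predict(0, -1)
print(coefs)
print(slope_when_x2_high, slope_when_x2_low)
```

The sign of the ``x1`` effect flips between the two values of ``x2``, which is exactly the kind of pattern a contour plot of the predictions makes visible.
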
Naturally, you can directly pass a dataframe and adjust the aesthetics of the plot:

In [ ]:
sns.interactplot("x1", "x2", "y1", df, cmap="coolwarm", filled=True, levels=25);
The two underlying plot functions are ``contourf()`` and ``plot()``, both of which can be tweaked with a keyword argument dictionary. Note the appearance when we plot data that was not simulated with an interaction:

In [ ]:
sns.interactplot("x1", "x2", "y2", df, filled=True,
                 scatter_kws={"color": "dimgray"},
                 contour_kws={"alpha": .5});
This works for logistic regression too:

In [ ]:
pal = sns.blend_palette(["#4169E1", "#DFAAEF", "#E16941"], as_cmap=True)
levels = np.linspace(0, 1, 11)
sns.interactplot("x1", "x2", "y_flip", df, levels=levels, cmap=pal, logistic=True);

Plotting many pairwise relationships with :func:`corrplot`
----------------------------------------------------------

Sometimes, you want a bird's-eye view of a large dataset to see what kind of relationships might exist in it. The :func:`corrplot` function will calculate a correlation matrix for a dataset and draw a heat map with the correlation values. By default, it also uses a permutation test to get p values for the correlation matrix in a way that corrects for the multiple comparisons.

In [ ]:
titanic = sns.load_dataset("titanic").dropna()
attention = sns.load_dataset("attention")
sns.set_context(rc={"figure.figsize": (8, 8)})
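
One common way to correct a permutation test for multiple comparisons is the max-statistic approach, sketched below with the standard library. Whether :func:`corrplot` uses exactly this scheme is an assumption here; ``pearson``, ``max_stat_pvalues``, and the simulated columns are hypothetical.

```python
import random
from itertools import combinations

def pearson(u, v):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(u)
    mu, mv = sum(u) / n, sum(v) / n
    suv = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    suu = sum((a - mu) ** 2 for a in u)
    svv = sum((b - mv) ** 2 for b in v)
    return suv / (suu * svv) ** 0.5

def max_stat_pvalues(columns, n_perm=1000, seed=0):
    """Permutation p values for all pairwise correlations. Each observed |r|
    is compared with the null distribution of the largest |r| found anywhere
    in a dataset whose columns were shuffled independently."""
    rng = random.Random(seed)
    names = sorted(columns)
    pairs = list(combinations(names, 2))
    observed = {p: abs(pearson(columns[p[0]], columns[p[1]])) for p in pairs}
    null_max = []
    for _ in range(n_perm):
        shuffled = {k: rng.sample(columns[k], len(columns[k])) for k in names}
        null_max.append(max(abs(pearson(shuffled[a], shuffled[b]))
                            for a, b in pairs))
    return {p: sum(m >= observed[p] for m in null_max) / n_perm for p in pairs}

rng = random.Random(0)
x = [rng.gauss(0, 1) for _ in range(60)]
data = {
    "x": x,
    "y": [xi + 0.3 * rng.gauss(0, 1) for xi in x],   # strongly related to x
    "z": [rng.gauss(0, 1) for _ in range(60)],       # unrelated to both
}
pvals = max_stat_pvalues(data)
for pair, p in pvals.items():
    print(pair, p)
```

Every pair is judged against the same "largest correlation under the null" distribution, so adding more variables makes the test stricter. The cost grows with the number of permutations.
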

Note that if you have a huge dataset, the permutation test will take a while. Of course, if you have a huge dataset, p values will not be particularly relevant, so you can turn off the significance testing. It's also possible to choose a different colormap, but choose wisely! Don't even try using the ``"jet"`` map; you'll get a ``ValueError``. The colorbar itself is also optional.

In [ ]:
sns.corrplot(titanic, sig_stars=False, cmap="RdBu_r", cbar=False);
By default, the colormap is centered on 0 and uses a diverging map, which is appropriate since 0 is a meaningful boundary and both large positive and negative values are interesting. Sometimes, though, you are only interested in either positive or negative values. In these cases, you can set the tail for the significance test, which will also change the default colormap to a sequential map.

In [ ]:
sns.corrplot(titanic, sig_tail="upper");
It’s also possible to specify the range for the colormap. Note that setting the test direction modifies the colormap and range, but the inverse is not true; the stars here will still correspond to a two-tailed test.

In [ ]:
sns.corrplot(titanic, cmap_range=(-.3, 0));
You might also have many variables, in which case the correlation coefficient annotation may not fit well. In this case, it can be turned off, and the variable names can be moved to the sides of the plot:

In [ ]:
f, ax = plt.subplots(figsize=(10, 10))
cmap = sns.blend_palette(["#00008B", "#6A5ACD", "#F0F8FF",
                          "#FFE6F8", "#C71585", "#8B0000"], as_cmap=True)
d = np.random.standard_t(20, (100, 30))
sns.corrplot(d, annot=False, diag_names=False, cmap=cmap)

Plotting model coefficients with :func:`coefplot`
-------------------------------------------------

When building a complex linear model, it can be more effective to visualize the model coefficients than to present them in a table. The :func:`coefplot` function takes a `patsy `_ formula string and fits the model specified by that string using `statsmodels `_. It then plots the model coefficients and confidence intervals:

In [ ]:
sns.coefplot("tip ~ scale(total_bill) + size + time + sex + smoker", tips);
The intercept is ignored by default, but you can add it if it's relevant:

In [ ]:
sns.coefplot("score ~ center(solutions) * attention", attention, intercept=True, palette="Set1");
When you have a dataset with repeated measures, :func:`coefplot` can fit the model separately for each unit and plot the coefficients so they can be easily compared across your sample:

In [ ]:
sns.coefplot("score ~ center(solutions)", attention, "subject", intercept=True);
We hope that this collection of features will help you understand the relationships in your datasets. Please get in touch if you have a visualization problem in this domain that these tools don't seem suited for.