.. _dataset_exploration:

.. currentmodule:: seaborn

Visual dataset exploration
==========================

In [ ]:
%matplotlib inline

In [ ]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [ ]:
iris = sns.load_dataset("iris")
flights = sns.load_dataset("flights")
networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])
.. _heatmap:

Visualizing matrices with :func:`heatmap`
-----------------------------------------

Often the easiest way to visualize a reasonably large table of data is to encode the value in each cell with a color and plot a heatmap. This can be accomplished with the :func:`heatmap` function. Note that, unlike many other seaborn functions, :func:`heatmap` expects its input to be a rectangular table of values, with one variable in the rows and another in the columns. Your dataset may instead be in tidy (long-form) format, as with the ``flights`` example.

Fortunately, it's easy to pivot a tidy dataframe into this rectangular format. Note that pivoting sorts the index by default, so if your index has an ordering that doesn't correspond to its alphabetical order (such as month names), you may need to reorder the rows after the pivot operation. It's also possible to collapse over a third variable (using an aggregation like a mean or sum) with the ``pivot_table`` function.
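The reshaping described above can be sketched on a small, made-up tidy table (the ``sales`` data below is hypothetical, not a seaborn dataset). ``pivot`` needs exactly one value per row/column pair, ``reindex`` restores a non-alphabetical row order, and ``pivot_table`` collapses duplicate pairs with an aggregation.

```python
import pandas as pd

# Hypothetical tidy data: two stores, so two observations per (month, year) pair
sales = pd.DataFrame({
    "month": ["Mar", "Jan", "Feb", "Mar", "Jan", "Feb"] * 2,
    "year": [2000] * 6 + [2001] * 6,
    "store": ["A"] * 3 + ["B"] * 3 + ["A"] * 3 + ["B"] * 3,
    "units": [3, 1, 2, 6, 4, 5, 9, 7, 8, 12, 10, 11],
})

# pivot requires unique (index, column) pairs, so keep a single store here
store_a = sales[sales["store"] == "A"]
wide = store_a.pivot(index="month", columns="year", values="units")
# The index comes back sorted alphabetically
pivot_order = wide.index.tolist()

# Reorder the rows to calendar order after pivoting
wide = wide.reindex(["Jan", "Feb", "Mar"])

# pivot_table averages over the third variable ("store") for duplicate pairs
mean_wide = sales.pivot_table(index="month", columns="year",
                              values="units", aggfunc="mean")
mean_wide = mean_wide.reindex(["Jan", "Feb", "Mar"])
print(mean_wide)
```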

In [ ]:
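# Reshape the tidy data: one row per month, one column per year
flights_rect = flights.pivot(index="month", columns="year", values="passengers")
# Restore calendar order (the first 12 rows of flights list the months in order)
flights_rect = flights_rect.loc[flights.month.iloc[:12]]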

Data in this format can be passed to :func:`heatmap` to produce an easily-interpretable visualization.
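A minimal version of that call is sketched below. So that it runs without downloading anything, it uses a small random table (a hypothetical stand-in for ``flights_rect``) and the non-interactive Agg backend; in the notebook you would simply call ``sns.heatmap(flights_rect)``.

```python
import matplotlib
matplotlib.use("Agg")  # draw off-screen; not needed in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for flights_rect: months in the rows, years in the columns
months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
          "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
rng = np.random.default_rng(0)
table = pd.DataFrame(rng.integers(100, 600, size=(12, 5)),
                     index=months, columns=range(1950, 1955))

# heatmap maps each cell's value to a color and returns the Axes it drew on
ax = sns.heatmap(table)
```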

If you care about the precise numeric values, you can also annotate each cell with its value.

In [ ]:
sns.heatmap(flights_rect, annot=True, fmt="d");
Because the color is an encoding of the data, it's very important to use an appropriate colormap. (See the :link:`palette tutorial ` for more information about different kinds of color palettes and how to choose one that is appropriate for your data.) :func:`heatmap` tries to choose good defaults based on some heuristics about your data. For example, if your data span 0 (that is, contain both negative and positive values), it assumes that a diverging colormap is more appropriate.

In [ ]:
network_corr = networks.iloc[:, :12].corr()
This heuristic is not always right! For example, you might be plotting temperature over time in several cities, where 0 isn't really a meaningful midpoint; a sequential colormap would be better in that case. When the data are inferred to be diverging, the colormap limits are made symmetric around the midpoint, using whichever anchor point lies farther from it (by default, the anchor points are the extreme values of the data, or the 2nd and 98th percentiles if ``robust`` is ``True``).
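The limit logic described above can be sketched in NumPy. The helper ``diverging_limits`` is hypothetical: it illustrates the rule as this section describes it, not seaborn's own code, and seaborn's behavior may differ between versions.

```python
import numpy as np


def diverging_limits(data, vmin=None, vmax=None, center=None, robust=False):
    """Hypothetical helper: colormap limits under the heuristics described above.

    Returns (vmin, vmax, divergent).
    """
    values = np.asarray(data, dtype=float)
    values = values[~np.isnan(values)]
    # Default anchors: the extremes, or the 2nd/98th percentiles when robust=True
    if vmin is None:
        vmin = np.percentile(values, 2) if robust else values.min()
    if vmax is None:
        vmax = np.percentile(values, 98) if robust else values.max()
    # Data spanning 0, or an explicit center, is treated as diverging
    divergent = (vmin < 0 < vmax) or center is not None
    if center is None:
        center = 0.0
    if divergent:
        # Whichever anchor lies farther from the center sets both limits
        half_range = max(abs(vmin - center), abs(vmax - center))
        vmin, vmax = center - half_range, center + half_range
    return float(vmin), float(vmax), divergent


corr_like = np.array([[1.0, -0.3], [-0.3, 1.0]])
print(diverging_limits(corr_like, vmax=0.8))  # (-0.8, 0.8, True)
print(diverging_limits([12.0, 18.5, 25.0]))   # (12.0, 25.0, False)
```

Passing ``vmax=0.8`` for correlation-like data therefore also fixes the lower limit at -0.8 under this rule.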

In [ ]:
sns.heatmap(network_corr, vmax=.8, square=True);
There are other uses for diverging colormaps where the midpoint is not 0. For example, you might want to plot change relative to a specific comparison value. To set the midpoint, pass a value to ``center``, which also implies that the colormap should be diverging.

In [ ]:
sns.heatmap(flights_rect, center=flights_rect.loc["January", 1955]);
:func:`heatmap` is an Axes-level function, so you can use it in the context of a more complex figure. Plotting the colorbar is optional, and it can also be drawn in a specific existing Axes.

In [ ]:
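f = plt.figure(figsize=(7, 9))
# A 15-row grid: bars on top, heatmap in the middle, colorbar at the bottom
gs = plt.GridSpec(15, 1)
hist_ax = f.add_subplot(gs[:5])

# Total passengers per year (sum over the months in each column)
yearly_flights = flights_rect.sum(axis=0)
# Edge-aligned bars of width 1 line up with the heatmap columns, which span 0 to 12
hist_ax.bar(range(12), yearly_flights, 1, align="edge", ec="w", lw=2, color=".3")
hist_ax.set(xlim=(0, 12), xticks=[], ylabel="passengers")

map_ax = f.add_subplot(gs[5:-2])
bar_ax = f.add_subplot(gs[-1])
# Draw the heatmap in map_ax and its colorbar, horizontally, in bar_ax
sns.heatmap(flights_rect, cmap="BuGn", ax=map_ax,
            cbar_ax=bar_ax, cbar_kws={"orientation": "horizontal"})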
.. _clustermap:

Visualizing clustered matrices with :func:`clustermap`
------------------------------------------------------

Beyond the heatmap, you may also be curious how the rows and columns of your rectangular dataset are related to each other. Enter :func:`clustermap`, which reorganizes the heatmap so that similar rows and similar columns are plotted closer together. This can help you discover structure in the dataset. Let's take a look at the flights data.
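Conceptually, the reordering works like the sketch below, which runs SciPy's hierarchical clustering directly on a small made-up table. It assumes the ``"average"`` linkage method and ``"euclidean"`` metric, which are :func:`clustermap`'s defaults; the real function also draws the dendrograms and the heatmap.

```python
import pandas as pd
from scipy.cluster import hierarchy

# Made-up table: rows r0 and r2 are similar, as are rows r1 and r3
table = pd.DataFrame([[1.0, 2.0, 1.5],
                      [9.0, 8.0, 9.5],
                      [1.2, 2.1, 1.4],
                      [8.8, 8.1, 9.7]],
                     index=["r0", "r1", "r2", "r3"],
                     columns=["c0", "c1", "c2"])


def leaf_order(matrix):
    """Row order implied by average-linkage clustering on Euclidean distances."""
    linkage = hierarchy.linkage(matrix, method="average", metric="euclidean")
    return hierarchy.leaves_list(linkage)


row_order = leaf_order(table.values)    # cluster the rows
col_order = leaf_order(table.values.T)  # cluster the columns via the transpose
clustered = table.iloc[row_order, col_order]
print(clustered)
```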

In [ ]:
cg = sns.clustermap(flights_rect)
By default, both the columns and the rows are clustered. For some datasets, though, you may want to preserve the original ordering:

In [ ]:
cg = sns.clustermap(flights_rect, col_cluster=False)
This is a little skewed because the number of passengers increases every year, so let's ``standard_scale`` the data (i.e., for each column, subtract its minimum and divide by its range, so that every year spans 0 to 1 and we can compare years on the same scale). We pass ``1`` to indicate that we want to standardize the columns, but we could also scale the rows, as in the next example.
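Assuming the description above, ``standard_scale=1`` amounts to the column-wise min-max rescaling sketched here in pandas (the small ``passengers`` table is made up); ``standard_scale=0`` applies the same operation to each row.

```python
import pandas as pd

# Made-up wide table: months in the rows, years in the columns
passengers = pd.DataFrame({1950: [100, 150, 200], 1951: [300, 360, 420]},
                          index=["Jan", "Feb", "Mar"])


def min_max_columns(df):
    """Rescale every column to [0, 1]: subtract its minimum, divide by its range."""
    return (df - df.min()) / (df.max() - df.min())


def min_max_rows(df):
    """The same rescaling applied to each row instead of each column."""
    return min_max_columns(df.T).T


print(min_max_columns(passengers))
```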

In [ ]:
cg = sns.clustermap(flights_rect, standard_scale=1)
You could also scale the rows by setting ``standard_scale=0``, to see how the different years cluster together if all the months are normalized across all years.

In [ ]:
cg = sns.clustermap(flights_rect, standard_scale=0)

We could also normalize the columns by their z-score, which subtracts the mean of each column and divides by its standard deviation, standardizing each column to have mean 0 and variance 1. This makes it easy to see which values are greater than the column mean and which are smaller.
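A pandas sketch of the column-wise z-score follows, again on a made-up table. It uses pandas' default sample standard deviation (``ddof=1``); that choice is an assumption of this sketch, and a different convention would change the scaled values slightly while keeping the mean at 0.

```python
import pandas as pd

# Made-up wide table: months in the rows, years in the columns
passengers = pd.DataFrame({1950: [100.0, 150.0, 200.0], 1951: [300.0, 360.0, 420.0]},
                          index=["Jan", "Feb", "Mar"])


def z_score_columns(df):
    """Subtract each column's mean and divide by its sample standard deviation."""
    return (df - df.mean()) / df.std()


z = z_score_columns(passengers)
print(z.round(3))
```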

In [ ]:
cg = sns.clustermap(flights_rect, z_score=1)
Next, you may want a quick way to see how the different years cluster together. Let's annotate the columns with colors that grow darker as the year increases.

In [ ]:
# One color per year, from light to dark green
col_colors = sns.color_palette('Greens', n_colors=flights_rect.shape[1])
cg = sns.clustermap(flights_rect, standard_scale=1, col_colors=col_colors)
Now we see that the first few years of the dataset, 1949, 1950, 1951 and 1953 all cluster together on the right, and we can see this easily because the lighter greens come together. We can also color the months by seasons, to see which seasons cluster together.

In [ ]:
col_colors = sns.color_palette('Greens', n_colors=flights_rect.shape[1])

season_colors = {'Winter': sns.color_palette('PuBu', n_colors=3),
                 'Spring': sns.color_palette('YlGn', n_colors=3),
                 'Summer': sns.color_palette('YlOrBr', n_colors=3),
                 'Fall': sns.color_palette('OrRd', n_colors=3)}

month_colors = {'January': season_colors['Winter'][1],
                'February': season_colors['Winter'][2],
                'March': season_colors['Spring'][0],
                'April': season_colors['Spring'][1],
                'May': season_colors['Spring'][2],
                'June': season_colors['Summer'][0],
                'July': season_colors['Summer'][1],
                'August': season_colors['Summer'][2],
                'September': season_colors['Fall'][0],
                'October': season_colors['Fall'][1],
                'November': season_colors['Fall'][2],
                'December': season_colors['Winter'][0]}
# A plain list in row order; a Series indexed 0-11 could be realigned to the month labels
row_colors = [month_colors[month] for month in flights_rect.index]

cg = sns.clustermap(flights_rect, standard_scale=1,
                    col_colors=col_colors, row_colors=row_colors)
Here, it's easy to see that the summer months all cluster together (along with September). You can also provide data in long form, as many other seaborn functions expect. In this case, pass a dictionary of keyword arguments for ``DataFrame.pivot`` as ``pivot_kws``, and the data will be reshaped behind the scenes.

In [ ]:
pivot_kws = dict(index="month", columns="year", values="passengers")
cg = sns.clustermap(flights, pivot_kws=pivot_kws, standard_scale=1)
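The data to be clustered must not contain ``NaN`` values. However, you can pass a boolean ``mask`` to hide cells in the plot, for example the cells that were missing in the original data. The mask affects only the visualization, not the clustering.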
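One way to follow this advice is sketched below, under an assumption of ours rather than anything :func:`clustermap` does itself: fill the missing cells (here with each column's mean) so the clustering sees complete data, and build the ``mask`` from the original ``NaN`` positions so those cells stay hidden. Both the table and the fill strategy are illustrative.

```python
import numpy as np
import pandas as pd

# Made-up table with two missing cells
raw = pd.DataFrame([[1.0, np.nan, 3.0],
                    [4.0, 5.0, np.nan],
                    [7.0, 8.0, 9.0]],
                   columns=["a", "b", "c"])

# Boolean mask: True where a cell was missing and should not be drawn
missing = raw.isna()
# Assumed imputation: replace each NaN with its column's mean so clustering can run
filled = raw.fillna(raw.mean())

# These could then be passed on as, e.g., clustermap(filled, mask=missing)
print(filled)
```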

In [ ]:
data2d = np.random.randn(32).reshape(4, 8)
data2d[:2, :2] += 2
mask = data2d > 1
sns.clustermap(data2d, mask=mask)
The function returns an object of type :class:`ClusterGrid`, which exposes some methods that are useful for postprocessing. For example, if you want to save the figure, you should call the ``savefig`` method on this object rather than ``plt.savefig``, or else the dendrograms may be chopped off. You can also do useful things like access the mapping between row and column indices in the clustered matrix and those in the original data.
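A short sketch of that post-processing follows, on a small random table so it runs offline (the data and the file name are made up). It uses ``ClusterGrid.savefig`` and the ``reordered_ind`` attribute of ``dendrogram_row`` and ``dendrogram_col``, which map positions in the clustered matrix back to positions in the original data.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render off-screen
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up data with labeled rows and columns
rng = np.random.default_rng(1)
data = pd.DataFrame(rng.normal(size=(6, 4)),
                    index=[f"row{i}" for i in range(6)],
                    columns=[f"col{j}" for j in range(4)])
cg = sns.clustermap(data)

# Position k in the clustered heatmap shows original row row_order[k]
row_order = cg.dendrogram_row.reordered_ind
col_order = cg.dendrogram_col.reordered_ind
clustered = data.iloc[row_order, col_order]

# Save through the grid so the dendrograms are kept in the saved image
out_path = os.path.join(tempfile.mkdtemp(), "clustermap.png")
cg.savefig(out_path)
```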

In [ ]: