Matplotlib & Seaborn: Introduction
In [ ]:
%matplotlib inline
Matplotlib is a Python package used widely throughout the scientific Python community to produce high-quality 2D publication graphics. It transparently supports a wide range of output formats, including PNG (and other raster formats), PostScript/EPS, PDF and SVG, and has interfaces for all of the major desktop GUI (graphical user interface) toolkits. It is a great package with lots of options.
However, matplotlib is...
The 800-pound gorilla — and like most 800-pound gorillas, this one should probably be avoided unless you genuinely need its power, e.g., to make a custom plot or produce a publication-ready graphic.
(As we’ll see, when it comes to statistical visualization, the preferred tack might be: “do as much as you easily can in your convenience layer of choice [ed.: e.g. directly from Pandas, or with seaborn], and then use matplotlib for the rest.”)
(quote used from this blogpost)
And that's what we mostly did: we just used the .plot function of Pandas. So, why do we learn matplotlib? Well, for the "...then use matplotlib for the rest" part; at some point, somehow, you will need it!
Matplotlib comes with a convenience sub-package called pyplot which, for consistency with the wider matplotlib community, should always be imported as plt:
In [ ]:
import numpy as np
import matplotlib.pyplot as plt
Figure, axes and axis
At the heart of every plot is the figure object. The "Figure" object is the top level concept which can be drawn to one of the many output formats, or simply just to screen. Any object which can be drawn in this way is known as an "Artist" in matplotlib.
Let's create our first artist using pyplot, and then show it:
In [ ]:
fig = plt.figure()
plt.show()
On its own, drawing the figure artist is uninteresting and will result in an empty piece of paper (that's why we didn't see anything above).
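The statement that a Figure can be drawn to many output formats can be checked even with this empty figure. The following is a minimal sketch, not part of the original notebook: it writes the figure to in-memory buffers (an assumption made here to avoid creating files) using the standard `savefig` method.

```python
import io
import matplotlib.pyplot as plt

fig = plt.figure()

# Write the same (still empty) figure to several in-memory buffers.
outputs = {}
for fmt in ("png", "pdf", "svg"):
    buffer = io.BytesIO()
    fig.savefig(buffer, format=fmt)
    outputs[fmt] = buffer.getvalue()

plt.close(fig)

# Each buffer now holds a complete file in the requested format.
for fmt, content in outputs.items():
    print(fmt, len(content), "bytes")
```

Passing a filename such as 'figure.png' instead of a buffer writes the file to disk; the format is then inferred from the extension.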
By far the most useful artist in matplotlib is the "Axes" artist. The Axes artist represents the "data space" of a typical plot. A rectangular Axes (the most common kind, but not the only one, e.g. polar plots) has 2 (confusingly named) Axis artists with tick labels and tick marks.
There is no limit on the number of Axes artists which can exist on a Figure artist. Let's go ahead and create a figure with a single Axes artist, and show it using pyplot:
In [ ]:
ax = plt.axes()
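To make the Axes/Axis distinction concrete, here is a small sketch (not from the original notebook) that inspects the two Axis artists of a rectangular Axes through its `xaxis` and `yaxis` attributes. The tick positions used are arbitrary illustration values.

```python
import matplotlib.pyplot as plt
from matplotlib.axis import XAxis, YAxis

ax = plt.axes()

# A rectangular Axes owns two Axis artists: ax.xaxis and ax.yaxis.
print(type(ax.xaxis).__name__, type(ax.yaxis).__name__)

# Each Axis artist manages the ticks and tick labels for its direction.
ax.xaxis.set_ticks([0.0, 0.5, 1.0])
x_ticks = list(ax.xaxis.get_ticklocs())
print(x_ticks)
```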
Matplotlib's pyplot module makes the process of creating graphics easier by allowing us to skip some of the tedious Artist construction. For example, we did not need to manually create the Figure artist with plt.figure because it was implicit that we needed a figure when we created the Axes artist.
Under the hood, matplotlib still had to create a Figure artist; it's just that we didn't need to capture it in a variable. We can access the created objects with the "state" functions found in pyplot, called gcf ("get current figure") and gca ("get current axes").
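A minimal sketch of these two state functions, showing that they hand back the objects pyplot created implicitly (the variable names are chosen here for illustration):

```python
import matplotlib.pyplot as plt

ax = plt.axes()          # implicitly creates a Figure as well

fig = plt.gcf()          # the Figure that pyplot created for us
current_ax = plt.gca()   # the Axes that is currently active

print(current_ax is ax)  # same object as the Axes we just created
print(ax.figure is fig)  # the Axes lives on the implicitly created Figure
```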
Some example data:
In [ ]:
x = np.linspace(0, 5, 10)
y = x ** 2
Observe the following difference:
1. pyplot style: plt... (you will see this a lot for code online!)
In [ ]:
plt.plot(x, y, '-')
2. creating objects
In [ ]:
fig, ax = plt.subplots()
ax.plot(x, y, '-')
Although a little bit more code is involved, the advantage is that we now have full control over where the plot axes are placed, and we can easily add more than one Axes to the figure:
In [ ]:
fig, ax1 = plt.subplots()
ax2 = fig.add_axes([0.2, 0.5, 0.4, 0.3])  # inset axes: [left, bottom, width, height] as figure fractions
ax1.plot(x, y, '-')
ax1.set_ylabel('y')
ax2.set_xlabel('x')
ax2.plot(x, y*2, 'r-')
In [ ]:
fig, ax = plt.subplots()
ax.plot(x, y, '-')
# ...
In [ ]:
x = np.linspace(-1, 0, 100)
fig, ax = plt.subplots()
# Adjust the created axes so that its topmost extent is 0.8 of the figure.
fig.subplots_adjust(top=0.8)
ax.plot(x, x**2, color='0.4', label="power 2")
ax.plot(x, x**3, color='0.8', linestyle='--', label="power 3")
fig.suptitle('Figure title', fontsize=18,
             fontweight='bold')
ax.set_title('Axes title', fontsize=16)
ax.set_xlabel('The X axis')
ax.set_ylabel('The Y axis $y=f(x)$', fontsize=16)
ax.set_xlim(-1.0, 1.1)
ax.set_ylim(-0.1, 1.)
ax.text(0.5, 0.2, 'Text centered at (0.5, 0.2)\nin data coordinates.',
        horizontalalignment='center', fontsize=14)
# transform=ax.transAxes places the text in Axes coordinates (0-1 within the axes)
ax.text(0.5, 0.5, 'Text centered at (0.5, 0.5)\nin Axes coordinates.',
        horizontalalignment='center', fontsize=14,
        transform=ax.transAxes, color='grey')
ax.legend(loc='upper right', frameon=True, ncol=2)
For more information on legend positioning, check this post on stackoverflow!
Another nice blogpost about customizing matplotlib figures: http://pbpython.com/effective-matplotlib.html
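The two ax.text calls in the cell above differ only in the transform keyword. As a hedged sketch (not from the original notebook), the following shows how a position given in Axes coordinates maps back to data coordinates for the same axis limits; the composed transform `to_data` is a name introduced here for illustration.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(-1.0, 1.1)
ax.set_ylim(-0.1, 1.0)

# Default: the position is interpreted in data coordinates.
ax.text(0.5, 0.2, "data coordinates", horizontalalignment="center")

# With transform=ax.transAxes, (0, 0) is the lower-left and (1, 1) the
# upper-right corner of the Axes, independent of the axis limits.
ax.text(0.5, 0.5, "axes coordinates", horizontalalignment="center",
        transform=ax.transAxes)

# Convert the centre of the Axes back to data coordinates.
to_data = ax.transAxes + ax.transData.inverted()
center_x, center_y = to_data.transform((0.5, 0.5))
print(center_x, center_y)
```

The centre of the Axes is the midpoint of the axis limits, so it moves in data coordinates whenever the limits change, while text placed with ax.transAxes stays put.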
The power of the object-oriented way of working is that it makes it possible to change everything. However, mostly we just want a good-looking plot quickly. Matplotlib provides a number of styles that can be used to change many settings at once:
In [ ]:
plt.style.available
In [ ]:
x = np.linspace(0, 10)
# Note: from matplotlib 3.6 on, the seaborn styles are named 'seaborn-v0_8-muted', etc.
with plt.style.context('seaborn-muted'):  # 'ggplot', 'bmh', 'grayscale', 'seaborn-whitegrid'
    fig, ax = plt.subplots()
    ax.plot(x, np.sin(x) + x + np.random.randn(50))
    ax.plot(x, np.sin(x) + 0.5 * x + np.random.randn(50))
    ax.plot(x, np.sin(x) + 2 * x + np.random.randn(50))
We will not start a discussion about colors and styles here; just pick your favorite style!
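A style context only applies inside its with block. The sketch below (an illustration, not part of the original notebook) checks this by reading one setting, `axes.facecolor`, before, inside and after the block; the 'ggplot' style is used because it ships with matplotlib.

```python
import matplotlib.pyplot as plt

before = plt.rcParams["axes.facecolor"]

with plt.style.context("ggplot"):
    inside = plt.rcParams["axes.facecolor"]  # value set by the ggplot style
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4])

after = plt.rcParams["axes.facecolor"]       # restored on leaving the block

print(before, inside, after)
```

To apply a style for the rest of the session instead, plt.style.use can be called once with the style name.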
What we have been doing while plotting with Pandas:
In [ ]:
import pandas as pd
In [ ]:
aqdata = pd.read_csv('data/20000101_20161231-NO2.csv', sep=';', skiprows=[1], na_values=['n/d'],
                     index_col=0, parse_dates=True)
aqdata = aqdata["2014":].resample('D').mean()
In [ ]:
aqdata.plot()
In [ ]:
aqdata.plot(figsize=(16, 6)) # shift tab this!
Making this with matplotlib...
In [ ]:
fig, ax = plt.subplots(figsize=(16, 6))
ax.plot(aqdata.index, aqdata["BASCH"],
        aqdata.index, aqdata["BONAP"],
        aqdata.index, aqdata["PA18"],
        aqdata.index, aqdata["VERS"])
ax.legend(["BASCH", "BONAP", "PA18", "VERS"])
or...
In [ ]:
fig, ax = plt.subplots(figsize=(16, 6))
for station in aqdata.columns:
    ax.plot(aqdata.index, aqdata[station], label=station)
ax.legend()
In [ ]:
axs = aqdata.plot(subplots=True, sharex=True,
                  figsize=(16, 8), colormap='viridis',  # Dark2
                  fontsize=15)
Mimicking this in matplotlib (just as a reference):
In [ ]:
from matplotlib import cm
import matplotlib.dates as mdates
# list comprehension to set up one color per station
colors = [cm.viridis(x) for x in np.linspace(0.0, 1.0, len(aqdata.columns))]
fig, axs = plt.subplots(4, 1, figsize=(16, 8))
for ax, col, station in zip(axs, colors, aqdata.columns):
    ax.plot(aqdata.index, aqdata[station], label=station, color=col)
    ax.legend()
    # Axes.is_last_row() was removed in matplotlib 3.6; ask the SubplotSpec instead
    if not ax.get_subplotspec().is_last_row():
        ax.xaxis.set_ticklabels([])
        ax.xaxis.set_major_locator(mdates.YearLocator())
    else:
        ax.xaxis.set_major_locator(mdates.YearLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        ax.set_xlabel('Time')
    ax.tick_params(labelsize=15)
fig.autofmt_xdate()
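The pandas call earlier used sharex=True, and plt.subplots accepts the same keyword. The following sketch shows it on a small, hypothetical stand-in for aqdata (random values for two made-up stations "A" and "B"), so that it runs without the course data file.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical stand-in for aqdata: two stations with random daily values.
index = pd.date_range("2014-01-01", periods=365, freq="D")
rng = np.random.default_rng(0)
demo = pd.DataFrame({"A": rng.normal(40, 10, 365),
                     "B": rng.normal(30, 8, 365)}, index=index)

# sharex=True links the x-axes of all rows and hides the tick labels
# of every row except the last one.
fig, axs = plt.subplots(len(demo.columns), 1, sharex=True, figsize=(16, 6))
for ax, station in zip(axs, demo.columns):
    ax.plot(demo.index, demo[station], label=station)
    ax.legend()
axs[-1].set_xlabel("Time")
```

With shared axes, zooming or setting limits on one subplot changes all of them, which removes the need to clear the tick labels of the upper rows by hand.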
In [ ]:
aqdata.columns
In [ ]:
fig, ax = plt.subplots() #prepare a matplotlib figure
aqdata.plot(ax=ax) # use pandas for the plotting
# Provide further adaptations with matplotlib:
ax.set_xlabel("")
ax.tick_params(labelsize=15, pad=8, which='both')
fig.suptitle('Air quality station time series', fontsize=15)
In [ ]:
fig, (ax1, ax2) = plt.subplots(2, 1)  # create two axes with matplotlib
aqdata[["BASCH", "BONAP"]].plot(ax=ax1)  # plot the two time series of the same location on the first axes
aqdata["PA18"].plot(ax=ax2)  # plot the other station on the second axes
# further adapt with matplotlib
ax1.set_ylabel("BASCH")
ax2.set_ylabel("PA18")
ax2.legend()
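The pattern in the last two cells works because pandas draws onto the Axes it is given and returns that same object. A minimal sketch, again with a hypothetical stand-in DataFrame instead of aqdata:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical stand-in for aqdata.
index = pd.date_range("2014-01-01", periods=30, freq="D")
demo = pd.DataFrame({"A": np.arange(30.0), "B": np.arange(30.0) * 2},
                    index=index)

fig, ax = plt.subplots()
returned = demo.plot(ax=ax)   # pandas draws onto the Axes we pass in

# pandas hands back the same Axes, so every matplotlib method applies.
ax.set_ylabel("value")
n_lines = len(ax.get_lines())
print(returned is ax, n_lines)
```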
In [ ]:
import seaborn as sns
A scatterplot matrix comparing the four stations, colored by month:
In [ ]:
aqdata["month"] = aqdata.index.month
In [ ]:
# use .loc to select the rows of 2014 (aqdata["2014"] no longer selects rows in pandas >= 2.0)
sns.pairplot(aqdata.loc["2014"].dropna(),
             vars=['BASCH', 'BONAP', 'PA18', 'VERS'],
             diag_kind='kde', hue="month")
We will use the Titanic example again:
In [ ]:
titanic = pd.read_csv('data/titanic.csv')
In [ ]:
titanic.head()
Histogram: getting the univariate distribution of the Age
In [ ]:
fig, ax = plt.subplots()
# histplot with kde=True replaces the deprecated distplot (seaborn >= 0.11); seaborn does not like NaN values...
sns.histplot(titanic["Age"].dropna(), kde=True, stat="density", ax=ax)
sns.rugplot(titanic["Age"].dropna(), color="g", ax=ax)  # rugplot draws small lines at the individual data point locations
ax.set_ylabel("Density")
Compare two variables (scatter-plot):
In [ ]:
g = sns.jointplot(x="Fare", y="Age",
data=titanic,
kind="scatter") #kde, hex
In [ ]:
g = sns.jointplot(x="Fare", y="Age",
data=titanic,
kind="scatter") #kde, hex
# Adapt the properties with matplotlib by changing the available axes objects
g.ax_marg_x.set_ylabel("Frequency")
g.ax_joint.set_facecolor('0.1')
g.ax_marg_y.set_xlabel("Frequency")
Who likes regressions?
In [ ]:
fig, ax = plt.subplots()
sns.regplot(x="Fare", y="Age", data=titanic, ax=ax, lowess=False)
# adding the small lines to indicate individual data points
sns.rugplot(titanic["Fare"].dropna(), axis='x',
color="#6699cc", height=0.02, ax=ax)
sns.rugplot(titanic["Age"].dropna(), axis='y',
color="#6699cc", height=0.02, ax=ax)
When you also want to take a category into account in the regression, use lmplot (which is built on top of FacetGrid):
In [ ]:
sns.lmplot(x="Fare", y="Age", hue="Sex",
data=titanic)
In [ ]:
sns.lmplot(x="Fare", y="Age", hue="Sex",
col="Survived", data=titanic)
Another way to split a plot into columns, colors, ... based on specific categorical columns is catplot (called factorplot in seaborn versions before 0.9):
In [ ]:
titanic.head()
In [ ]:
sns.catplot(x="Sex",
            y="Fare",
            col="Pclass",
            data=titanic,
            kind="point")  # the former factorplot default; also 'strip', 'violin', ...
In [ ]:
sns.catplot(x="Sex", y="Fare", col="Pclass", row="Embarked",
            data=titanic, kind='bar')
In [ ]:
# 'height' was called 'size' before seaborn 0.9
g = sns.catplot(x="Survived", y="Fare", hue="Sex",
                col="Embarked", data=titanic,
                kind="box", height=4, aspect=.5);
g.fig.set_figwidth(15)
g.fig.set_figheight(6)
For more in-depth material:
We only use matplotlib (or matplotlib-based plotting) in this workshop, and it is still the main plotting library for many scientists, but it is not the only existing plotting library.
A nice overview of the landscape of visualisation tools in Python was given by Jake VanderPlas at PyCon 2017: https://speakerdeck.com/jakevdp/pythons-visualization-landscape-pycon-2017
Bokeh (http://bokeh.pydata.org/en/latest/): interactive, web-based visualisation
In [ ]:
from bokeh.io import output_notebook
output_notebook()
In [ ]:
from bokeh.plotting import figure, show
from bokeh.sampledata.iris import flowers
colormap = {'setosa': 'red', 'versicolor': 'green', 'virginica': 'blue'}
colors = [colormap[x] for x in flowers['species']]
p = figure(title = "Iris Morphology")
p.xaxis.axis_label = 'Petal Length'
p.yaxis.axis_label = 'Petal Width'
p.circle(flowers["petal_length"], flowers["petal_width"],
color=colors, fill_alpha=0.2, size=10)
show(p)
Altair (https://altair-viz.github.io/index.html): declarative statistical visualization library for Python, based on Vega.
In [ ]:
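# Altair >= 2 no longer provides load_dataset; the example datasets
# live in the separate vega_datasets package.
import altair as alt
from vega_datasets import data
# load built-in dataset as a pandas DataFrame
iris = data.iris()
alt.Chart(iris).mark_circle().encode(
    x='petalLength',
    y='petalWidth',
    color='species',
)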
This notebook is partly based on material of © 2016, Joris Van den Bossche and Stijn Van Hoey (mailto:jorisvandenbossche@gmail.com, mailto:stijnvanhoey@gmail.com), licensed under CC BY 4.0 Creative Commons, and partly on material of the Met Office (Copyright (C) 2013 SciTools, GPL licensed): https://github.com/SciTools/courses