This notebook was put together by [Jake Vanderplas](http://www.vanderplas.com) for UW's [Astro 599](http://www.astro.washington.edu/users/vanderplas/Astr599/) course. Source and license info is on [GitHub](https://github.com/jakevdp/2013_fall_ASTR599/).
We've discussed some of the basic interface to matplotlib plots previously. Here we'll go a bit more in-depth, and discuss things like multi-panel plots, inset plots, axes formatting, and much more.
IPython has a `pylab` mode and (in version 1.0+) a `matplotlib` mode. Whenever you're doing plotting in IPython, you should enable one of these.
Adding the `inline` flag will use the appropriate backend to make figures appear inline in the notebook. Otherwise, figures will pop out in separate windows.
At the command line, type
[~]$ ipython --pylab
or, in version 1.0+,
[~]$ ipython --matplotlib
Within the notebook or IPython command-line, type
In[]: %pylab inline
or
In[]: %matplotlib inline
The difference between `pylab` mode and `matplotlib` mode is that `pylab` includes a bunch of silent imports, including
import numpy as np
import matplotlib.pyplot as plt
from pylab import *
This can be convenient, but is often confusing (for example, it replaces the builtin `sum()` function with numpy's `sum()`!). For that reason we'll use `matplotlib` mode. If your IPython version is older than 1.0, you'll have to switch this to `pylab` mode.
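To make the `sum()` pitfall concrete, here is a small sketch (Python 3 syntax; the variable names are illustrative, not from the notebook). The two functions interpret their second positional argument differently, so code that silently switches from one to the other can change its result:

```python
import builtins

import numpy as np

values = range(5)

# Built-in sum: the second positional argument is the *start* value,
# so this computes (0 + 1 + 2 + 3 + 4) + (-1).
builtin_result = builtins.sum(values, -1)

# NumPy sum: the second positional argument is the *axis* to sum over,
# so this simply sums along the last axis.
numpy_result = np.sum(values, -1)

print(builtin_result, numpy_result)
```

The built-in version gives 9 and the NumPy version gives 10, which is exactly the kind of quiet discrepancy that `from pylab import *` can introduce.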
In [1]:
# check the IPython version
import IPython
print(IPython.__version__)
In [2]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
In [3]:
# create a figure using the matlab-like interface
x = np.linspace(0, 10, 1000)
plt.subplot(2, 1, 1) # 2x1 grid, first plot
plt.plot(x, np.sin(x))
plt.title('Trig is easy.')
plt.subplot(2, 1, 2) # 2x1 grid, second plot
plt.plot(x, np.cos(x))
plt.xlabel('x')
Out[3]:
The other interface is an object-oriented interface, where we explicitly pass around references to the plot elements we want to work with:
In [4]:
# create the same figure using the object-oriented interface
fig = plt.figure()
ax1 = fig.add_subplot(2, 1, 1)
ax1.plot(x, np.sin(x))
ax1.set_title("Trig is easy")
ax2 = fig.add_subplot(2, 1, 2)
ax2.plot(x, np.cos(x))
ax2.set_xlabel('x')
Out[4]:
These two interfaces are convenient for different circumstances. I find that for doing quick, simple plots, the scripted interface is often easiest. On the other hand, when I want more sophisticated plots, the object-oriented interface is simpler and more powerful. In fact, the scripted interface has several distinct limitations.
It's good practice to use the object-oriented interface.
That is, you should get in the habit of avoiding `plt.<command>` calls when you can reference a specific axes or figure object instead.
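One place where this habit pays off is in reusable plotting helpers. The sketch below is only an illustration (the `plot_wave` helper and its arguments are hypothetical, not part of the notebook): because the function receives the axes explicitly, it never has to guess which panel is "current".

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_wave(ax, func, label):
    """Draw func(x) on the given axes and return the created line (hypothetical helper)."""
    x = np.linspace(0, 10, 1000)
    (line,) = ax.plot(x, func(x), label=label)
    ax.legend(loc="upper right")
    return line


fig = plt.figure()
ax_top = fig.add_subplot(2, 1, 1)
ax_bottom = fig.add_subplot(2, 1, 2)

# Each call targets exactly the axes it is handed, in any order.
cos_line = plot_wave(ax_bottom, np.cos, "cos(x)")
sin_line = plot_wave(ax_top, np.sin, "sin(x)")
```

With the `plt.<command>` style, the same helper would draw on whichever axes happened to be active at call time.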
There are four main ways to create multi-panel plots in matplotlib. From lowest to highest-level they are (roughly):
- `fig.add_axes()`: useful for creating inset plots.
- `fig.add_subplot()`: useful for simple multi-panel plots.
- `plt.subplots()`: convenience function to create multiple subplots.
- `plt.GridSpec()`: useful for more involved layouts.
In [5]:
fig = plt.figure()
main_ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
In [6]:
inset_ax = fig.add_axes([0.6, 0.6, 0.25, 0.25])
fig
Out[6]:
In [7]:
main_ax.plot(np.random.rand(100), color='gray')
inset_ax.plot(np.random.rand(20), color='black')
fig
Out[7]:
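The list passed to `fig.add_axes()` is `[left, bottom, width, height]`, with each number given as a fraction of the figure size. As a quick check (a minimal sketch; the variable names are illustrative), the position reported by the axes matches the rectangle that was requested:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt

fig = plt.figure()

# [left, bottom, width, height], all in figure-fraction units (0 to 1)
rect = [0.6, 0.6, 0.25, 0.25]
inset_ax = fig.add_axes(rect)

# get_position() returns a bounding box; .bounds is (x0, y0, width, height)
bounds = inset_ax.get_position().bounds

# The inset therefore ends at left + width and bottom + height
right_edge = rect[0] + rect[2]
top_edge = rect[1] + rect[3]
print(bounds, right_edge, top_edge)
```

So the inset above occupies the region from 0.6 to 0.85 of the figure in both directions.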
fig.add_subplot()
If you're trying to create multiple axes in a grid, you might use `add_axes()` repeatedly, but calculating the extent for each axes is not trivial. The `add_subplot()` method can streamline this.
The arguments are of the form `N_vertical, N_horizontal, Plot_number`, and the indices start at 1 (a holdover from Matlab):
In [8]:
fig = plt.figure()
for i in range(1, 7):
    ax = fig.add_subplot(2, 3, i)
    ax.text(0.45, 0.45, str(i), fontsize=24)
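The plot number counts across each row before moving down to the next one. A small sketch of that mapping (the `positions` dictionary and label text are illustrative only), using `divmod` to recover the zero-based row and column from the one-based plot number:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt

n_rows, n_cols = 2, 3
positions = {}

fig = plt.figure()
for i in range(1, n_rows * n_cols + 1):
    # Plot numbers start at 1 and run left to right, then top to bottom
    row, col = divmod(i - 1, n_cols)
    positions[i] = (row, col)

    ax = fig.add_subplot(n_rows, n_cols, i)
    ax.text(0.1, 0.45, "#{0}: row {1}, col {2}".format(i, row, col))
```

Plot number 4 in a 2x3 grid is therefore the first panel of the second row.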
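If you desire, you can adjust the spacing using `fig.subplots_adjust()`. The `left`, `right`, `bottom`, and `top` values are given relative to the figure size (i.e. between 0 and 1), while `hspace` and `wspace` are fractions of the average axes height and width: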
In [9]:
fig.subplots_adjust(left=0.1, right=0.9,
bottom=0.1, top=0.9,
hspace=0.4, wspace=0.4)
fig
Out[9]:
In [10]:
fig, ax = plt.subplots(2, 3)
for i in range(2):
    for j in range(3):
        ax[i, j].text(0.2, 0.45, str((i, j)), fontsize=20)
In [11]:
print(type(ax))
print(ax.shape)
print(ax.dtype)
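Because the axes come back as a NumPy object array, the usual array tools apply. The sketch below (Python 3 syntax; the panel titles are illustrative) uses the array's `flat` iterator to visit every panel without nested loops:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)

# axes.flat walks the 2x3 array in row-major order
for index, ax in enumerate(axes.flat):
    ax.set_title("panel {0}".format(index))

titles = [ax.get_title() for ax in axes.flat]
print(titles)
```

Note that with the default arguments, `plt.subplots()` called for a single panel returns a bare axes object rather than an array, which is why `fig, ax = plt.subplots()` works as a shortcut.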
An additional nice piece of this routine is the ability to specify that the subplots have a shared `x` or `y` axis: this ties together the axis limits and removes redundant tick labels:
In [12]:
fig, ax = plt.subplots(2, 3, sharex=True, sharey=True)
x = np.linspace(0, 10, 1000)
for i in range(2):
    for j in range(3):
        ax[i, j].plot(x, (j + 1) * np.sin((i + 1) * x))
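To see what "ties together the axis limits" means in practice, here is a minimal sketch (variable names are illustrative): changing the x-limits on one panel changes them on every panel that shares the axis.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 3, sharex=True, sharey=True)
x = np.linspace(0, 10, 1000)
for ax in axes.flat:
    ax.plot(x, np.sin(x))

# Zoom in on a single panel...
axes[0, 0].set_xlim(2, 5)

# ...and every other panel follows, because the x-axis is shared
shared_limits = [tuple(ax.get_xlim()) for ax in axes.flat]
print(shared_limits)
```

The same applies to interactive zooming and panning when the figure is shown in a window.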
In [13]:
gs = plt.GridSpec(3, 3) # a 3x3 grid
fig = plt.figure(figsize=(6, 6)) # figure size in inches
fig.add_subplot(gs[:, :])
fig.add_subplot(gs[1, 1])
Out[13]:
In [14]:
gs = plt.GridSpec(3, 3, wspace=0.4, hspace=0.4)
fig = plt.figure(figsize=(6, 6))
fig.add_subplot(gs[1, :2])
fig.add_subplot(gs[0, :2])
fig.add_subplot(gs[2, 0])
fig.add_subplot(gs[:2, 2])
fig.add_subplot(gs[2, 1:])
Out[14]:
Check out the documentation of `plt.GridSpec` for information on adjusting the subplot parameters. The keyword arguments are similar to those of `plt.subplots_adjust()`.
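As one illustration of those keyword arguments, the sketch below combines `wspace`/`hspace` with the `width_ratios` and `height_ratios` options of `GridSpec`, which are not used elsewhere in this notebook. The particular layout (one large panel with two narrow side panels) and the variable names are just an example:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(6, 6))

# 2x2 grid whose first column is three times wider than the second,
# and whose second row is three times taller than the first
gs = GridSpec(2, 2,
              width_ratios=[3, 1], height_ratios=[1, 3],
              wspace=0.1, hspace=0.1)

ax_main = fig.add_subplot(gs[1, 0])   # large lower-left panel
ax_top = fig.add_subplot(gs[0, 0])    # short panel above it
ax_right = fig.add_subplot(gs[1, 1])  # narrow panel beside it

# The requested ratio shows up in the final axes widths
width_ratio = ax_main.get_position().width / ax_right.get_position().width
print(width_ratio)
```

The spacing arguments only change the gaps between cells; the relative sizes of the cells themselves still follow the requested ratios.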
Let's use a multi-panel plot to show the correlations between three different variables. Here's an example of the type of plot we'd like to create:
(source)
Ignore the upper-right panel for the time being. The point is that there are three variables this plot compares: period, `g-r` color, and `r-i` color. By comparing the three panels, we gain some excellent intuition into the correlations between the three variables.
Your task: use the `asteroids5000.csv` data from the previous breakout session, and plot the semi-major axis, the ellipticity, and the inclination angle in the same manner. Note that these are, respectively, columns (1, 2, 3) in the file.
(Recall that the asteroids file can be found in the github repository, at `notebooks/data/asteroids5000.csv`, and you can load the data using `np.genfromtxt`).
This is another important piece of producing publication-quality plots: adjusting your axis ticks and tick labels. This is done through `plt.Formatter` and `plt.Locator` objects.
- `Formatter`, as you might imagine, adjusts the format of the tick labels.
- `Locator` adjusts the location of the ticks.

Additionally, you should be aware that matplotlib has two classes of ticks: major ticks and minor ticks.
In [15]:
fig, ax = plt.subplots() # trick to create a single axes
ax.plot(np.random.rand(100))
# set the major formatter for the x-axis
ax.xaxis.set_major_formatter(plt.FormatStrFormatter("x=%i"))
# set the major formatter for the y-axis
ax.yaxis.set_major_formatter(plt.FixedFormatter(['A', 'B', 'C', 'D']))
We see the general form used to set these. We first get the `axis` instance (not to be confused with the `axes` instance), then call the `set_major_formatter` or `set_minor_formatter` method, and pass a `Formatter` object.
The following formatters are available:
- `plt.ScalarFormatter`: default; choose the best format
- `plt.LogFormatter`: default for log plots
- `plt.NullFormatter`: no tick labels
- `plt.FixedFormatter`: a fixed list of tick labels
- `plt.FuncFormatter`: a function which returns a label
- `plt.FormatStrFormatter`: label derived from a format string

Note that several of these formatters make the most sense when being used along with a custom locator.
The `FuncFormatter` can be especially useful:
In [16]:
fig, ax = plt.subplots() # trick to create a single axes
ax.plot(np.random.rand(100))
def formatfunc(tick_loc, tick_num):
    return '{0}:{1}'.format(tick_loc, tick_num)
# set the major formatter for the x-axis
ax.xaxis.set_major_formatter(plt.FuncFormatter(formatfunc))
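The function receives the tick value and the tick's index, and whatever string it returns becomes the label. A slightly more practical sketch, assuming the x-axis represents a time in seconds (the `seconds_label` function and the unit are hypothetical), appends a unit to each tick value:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import ticker


def seconds_label(tick_loc, tick_num):
    """Format a tick value with a unit; tick_num (the tick index) is unused here."""
    return "{0:g} s".format(tick_loc)


fig, ax = plt.subplots()
ax.plot(np.linspace(0, 10, 50), np.random.rand(50))

formatter = ticker.FuncFormatter(seconds_label)
ax.xaxis.set_major_formatter(formatter)

# A formatter can also be called directly, which is handy for checking it
example_label = formatter(2.5, 0)
print(example_label)
```

Here the formatter classes are taken from `matplotlib.ticker`, which is where the `plt.*Formatter` names used above are defined.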
In [17]:
fig, ax = plt.subplots()
ax.plot(10 * np.random.rand(100))
ax.xaxis.set_major_locator(plt.MultipleLocator(10))
ax.yaxis.set_major_locator(plt.FixedLocator([np.e, 2 * np.pi, 0.8]))
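Locators can be set for minor ticks in the same way, through `set_minor_locator`. The following is a minimal sketch (the tick spacings are chosen arbitrarily) that places major ticks every 2 units and minor ticks every 0.5 units on the x-axis:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import ticker

fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x))
ax.set_xlim(0, 10)

# Major ticks at multiples of 2, minor ticks at multiples of 0.5
ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.5))

major_ticks = ax.xaxis.get_majorticklocs()
minor_ticks = ax.xaxis.get_minorticklocs()
print(major_ticks)
```

Minor ticks are drawn shorter than major ticks and are unlabeled unless a minor formatter is set as well.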
The following Locators are available:
- `plt.LinearLocator`: default: choose a suitable linear spacing
- `plt.LogLocator`: default for log plots: choose a suitable logarithmic spacing.
- `plt.IndexLocator`: default for index plots (where `x = range(len(y))`)
- `plt.NullLocator`: no ticks (this is the default for minor ticks)
- `plt.FixedLocator`: specify fixed tick locations
- `plt.MultipleLocator`: ticks at a multiple of the given number

For the first four exercises below, use `plot(x, y)` where `x` ranges from 0 to 10 and `y = sin(x)`.
1. Create a plot with no ticks and no labels
2. Create a plot with ticks but no labels
3. Create a plot where the x labels are in multiples of pi
4. Redo #3, but use a `FuncFormatter` to make the labels appear as $\pi$, $2\pi$, $3\pi$, etc. Note that latex math expressions are denoted by dollar signs, but you'll usually need raw strings: e.g. `r'$\pi$'` renders to $\pi$

Log scaling can be turned on by using, e.g.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, xscale='log')
ax.plot(...)

5. Make a log-log plot of the function $y = x^2$ where x ranges from 1 to 10.
6. Repeat plot #5, but change the axes so that the major ticks are powers of 2 rather than powers of 10. Turn the minor ticks off.
Matplotlib is extremely customizable. You can change the default of many different components using the `matplotlibrc` file, located at `~/.matplotlib/matplotlibrc`. If the file doesn't exist, then default settings are used. You can see a template matplotlibrc file here:
In [18]:
#%load http://matplotlib.org/_static/matplotlibrc
Additionally, if you'd just like to change parameters for the current python session, you can use the `plt.rc` function to dynamically set them.
The `plt.rc` function has the following signature:
rc(element_name, attr1=val1, attr2=val2)
and the results will be stored in the `plt.rcParams` dictionary:
In [19]:
print(plt.rcParams['lines.linewidth'])
In [20]:
plt.rc('lines', linewidth=5.0)
print(plt.rcParams['lines.linewidth'])
In [21]:
plt.plot(np.random.rand(10))
Out[21]:
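Keep in mind that `plt.rc` changes the setting for every plot made afterwards in the session. If a change is only wanted for one figure, one option in reasonably recent matplotlib versions is the `plt.rc_context` context manager, which restores the previous values on exit. A minimal sketch (variable names are illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt
import numpy as np

before = plt.rcParams["lines.linewidth"]

# Settings passed to rc_context only apply inside the with-block
with plt.rc_context({"lines.linewidth": 5.0}):
    inside = plt.rcParams["lines.linewidth"]
    fig, ax = plt.subplots()
    (line,) = ax.plot(np.random.rand(10))

after = plt.rcParams["lines.linewidth"]
print(before, inside, after)
```

The line created inside the block keeps its thick width, while plots made afterwards go back to the earlier setting.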
You can examine all the possible rcParams values by printing the keys of the rcParams dictionary.
In [22]:
list(plt.rcParams.keys())[:5]
Out[22]:
Here's a little fun for all the design-minded folks out there. One of the common complaints about matplotlib is that its default settings look a bit hokey when compared to graphics produced by tools like ggplot in the R language.
Here, you'll get a chance to try out some customizations of your RC settings to try to match a ggplot plot. Take a look at this image, copied from this blog post:
Your task is to adjust the `rcParams` settings so that the following plot command produces something very close to the above plot:
In [31]:
fig = plt.figure()
plt.scatter(np.random.randn(1000),
np.random.randn(1000))
Out[31]:
There are several values you'll probably want to adjust, including but not limited to:
figure.figsize
axes.grid
axes.axisbelow
axes.facecolor
axes.edgecolor
axes.color_cycle
axes.linewidth
xtick.labelsize
grid.alpha
grid.color
grid.linestyle
and probably others... Print `rcParams.keys()` to see all the possibilities.
A few other notes:
- Transparency is set with the `alpha` keyword.
- Colors can be specified by name (`"red"`, `"blue"`, `"gray"`, etc.) and by RGB hex value (`"#EAEAEA"`, `"#00FF44"`, `"#C494A5"`, etc.). You can see some of these hex color codes at this page.

If you finish this and still have some time, then do a Google image search for "ggplot" and find another plot to try to duplicate.
Note also that there are several people presently working on this problem in the matplotlib community: in particular, take a look at the mpltools package to see some examples of trying to streamline the creation of ggplot-style plots in matplotlib.
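For comparison once you've attempted the exercise: matplotlib releases newer than the one this notebook was written for include a `matplotlib.style` module with a bundled `ggplot` style sheet. The sketch below assumes such a release is installed; it applies the style temporarily and redraws the scatter plot from the exercise.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; not needed in the notebook
import matplotlib.pyplot as plt
import numpy as np

# Assumes a matplotlib version that provides matplotlib.style
has_ggplot = "ggplot" in plt.style.available

# style.context applies the style sheet only inside the with-block
with plt.style.context("ggplot"):
    grid_on = plt.rcParams["axes.grid"]
    fig, ax = plt.subplots()
    ax.scatter(np.random.randn(1000), np.random.randn(1000))

print(has_ggplot, grid_on)
```

A style sheet is just a collection of rcParams values, so printing `plt.rcParams` inside the block is one way to compare its choices with your own.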