Data visualization is one of the most important methods, if not the most important method, of communicating data science results. It's analogous to writing: if you can't visualize your results, you'll be hard-pressed to convince anyone else of them. By the end of this lecture, you should be able to use matplotlib to generate figures.
The Matplotlib package as we know it was conceived and designed by John Hunter in 2002, originally as an IPython plugin to enable MATLAB-style plotting.
IPython's creator, Fernando Perez, was at the time finishing his PhD and didn't have time to fully vet John's patch. So John took his fledgling plotting library and ran with it, releasing Matplotlib version 0.1 in 2003 and setting the stage for what would be the most flexible and cross-platform Python plotting library to date.
Matplotlib can run on a wide variety of operating systems and make use of a wide variety of graphical backends. That flexibility is a big part of why, despite some developers' complaints that it can feel bloated and clunky, it easily maintains the largest active user base and team of developers of any Python plotting library, which should keep it relevant for quite some time yet.
You've seen snippets of matplotlib in action in several assignments and lectures, but we haven't really formalized it yet. Like NumPy, matplotlib follows some usage conventions; it is almost always imported as follows:
In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
By far, the object we'll use most is plt from the second import; it contains the main plotting interface.
Let's say you're coding a standalone Python application, contained in a file myapp.py. You'll need to explicitly tell matplotlib to generate a figure and display it, via the show() command.
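A minimal sketch of what myapp.py might contain; the sine data and the make_figure() helper are hypothetical, chosen just for illustration:

```python
# myapp.py: a minimal standalone script that builds a figure and displays it.
import matplotlib.pyplot as plt
import numpy as np


def make_figure():
    """Build a simple line plot and return its Figure (hypothetical helper)."""
    fig = plt.figure()
    x = np.linspace(0, 10, 50)
    plt.plot(x, np.sin(x))
    return fig


if __name__ == "__main__":
    make_figure()
    plt.show()  # Opens a window; the script waits here until you close it.
```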
Then you can run the code from the command line:
$ python myapp.py
Beware: plt.show() does a lot of things under the hood, including interacting with your operating system's graphical backend. Matplotlib hides all these details from you, but as a consequence you should be careful to use plt.show() only once per Python session. Multiple calls to show() can lead to unpredictable behavior that depends on which backend is in use, so try your best to avoid them.
Remember back to our first lecture, when you learned how to fire up a Python prompt on the terminal? You can plot in that shell just as you can in a script!
In addition, you can enter "matplotlib mode" by using the %matplotlib magic command in the IPython shell. You'll notice in the above screenshot that the prompt is hovering below line [6], but no line [7] has emerged. That's because the shell is currently not in matplotlib mode, so it will wait indefinitely until you close the figure on the right.
By contrast, in matplotlib mode, you'll immediately get the next line of the prompt while the figure is still open. You can then edit the properties of the figure dynamically to update the plot. To force an update, you can use the command plt.draw().
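Here's a sketch of that workflow. In a live IPython session (after running %matplotlib) you'd type these lines one at a time and watch the window change; the matplotlib.use("Agg") line is only there so the snippet also runs on a machine without a display, and you'd leave it out in a real session. plt.ion() turns on matplotlib's interactive mode, which %matplotlib also switches on for you.

```python
import matplotlib
matplotlib.use("Agg")  # Headless backend so the sketch runs anywhere; omit in a live session.
import matplotlib.pyplot as plt
import numpy as np

plt.ion()  # Interactive mode: commands return immediately instead of blocking.
x = np.linspace(0, 10, 50)
(line,) = plt.plot(x, np.sin(x))  # plt.plot returns a list of Line2D objects.

# Edit properties of the existing plot...
line.set_ydata(np.cos(x))
line.set_color("red")
# ...then ask matplotlib to redraw the figure.
plt.draw()
```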
This is probably the mode you're most familiar with: plotting in a notebook, such as the one you're viewing right now.
Since matplotlib's default is to render its graphics in an external window, which a browser can't display, you'll have to specify otherwise when plotting in a notebook. You'll once again make use of the %matplotlib magic command, this time with the inline argument added to tell matplotlib to embed the figures in the notebook itself.
In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
x = np.random.random(10)
y = np.random.random(10)
plt.plot(x, y)
Note that you do NOT need to use plt.show()! In "inline" mode, matplotlib automatically renders whatever the "active" figure is once the cell containing your plotting commands finishes running.
Sometimes you'll want to save the plots you're making to files for use later, perhaps as part of a presentation to demonstrate to your bosses what you've accomplished.
In this case, you once again won't use the plt.show() command; instead, substitute in the plt.savefig() command.
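A minimal sketch of that, using arbitrary random data. The Agg backend line just means no window is ever opened, since we only write to disk:

```python
import matplotlib
matplotlib.use("Agg")  # No display needed when we only save to a file.
import matplotlib.pyplot as plt
import numpy as np

x = np.random.random(10)
y = np.random.random(10)
plt.plot(x, y)
plt.savefig("fig.png")  # Writes the current figure to fig.png; no plt.show() involved.
```

If you ever do call both, call savefig() before show(): once the display window has been closed, the figure may no longer exist, and you'd end up saving a blank image.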
An image file containing the plot will be created on the filesystem (in this case, fig.png).
Matplotlib is designed to operate nicely with lots of different output formats; PNG was just the example used here. The output format is inferred from the extension of the filename passed to savefig(). You can see all the other formats matplotlib supports with these commands:
In [3]:
fig = plt.figure()
fig.canvas.get_supported_filetypes()
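Because the format comes from the extension, switching formats is just a matter of changing the filename. A sketch with arbitrary filenames; savefig() also takes an explicit format argument if you'd rather not rely on the extension:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
x = np.linspace(0, 10, 50)
plt.plot(x, np.sin(x))

supported = fig.canvas.get_supported_filetypes()  # dict: extension -> description
print(sorted(supported))

# The extension picks the output format...
fig.savefig("sine.pdf")
fig.savefig("sine.svg")
# ...or name the format explicitly. The filename is then used verbatim:
# this writes a PNG called "sine_plot", with no extension appended.
fig.savefig("sine_plot", format="png")
```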
Ok, let's dive in with some plotting examples and how-tos!
The most basic kind of plot you can make is the line plot. This kind of plot uses (x, y) coordinate pairs and implicitly draws lines between them. Here's an example:
In [4]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
x = np.array([4, 5, 6])
y = np.array([9, 4, 7])
plt.plot(x, y)
Matplotlib sees we've created points at (4, 9), (5, 4), and (6, 7), and it connects each of these in turn with a line, producing the above plot. It also automatically scales the x and y axes of the plot so all the data fit visibly inside.
An important side note: matplotlib is stateful, which means it has some memory of what commands you've issued. So if you want to, say, include multiple different plots on the same figure, all you need to do is issue additional plotting commands.
In [5]:
x1 = np.array([4, 5, 6])
y1 = np.array([9, 4, 7])
plt.plot(x1, y1)
x2 = np.array([1, 2, 4])
y2 = np.array([4, 6, 9])
plt.plot(x2, y2)
They'll even be plotted in different colors. How nice!
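One way to see this statefulness directly: pyplot keeps track of a "current" figure, and every plotting command adds to it until you explicitly start a new one with plt.figure(). A small sketch (plt.gcf() means "get current figure", plt.gca() "get current axes"):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")  # Start from a clean slate.

plt.plot(np.array([4, 5, 6]), np.array([9, 4, 7]))
plt.plot(np.array([1, 2, 4]), np.array([4, 6, 9]))
first = plt.gcf()
print(len(plt.gca().lines))  # 2: both lines landed on the same axes

plt.figure()  # A fresh figure, which becomes the new "current" one.
plt.plot(np.array([1, 2, 3]), np.array([3, 1, 2]))
second = plt.gcf()
print(len(plt.gca().lines))  # 1: the earlier lines stayed on the first figure
```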
Line plots are nice, but let's say I really want a scatter plot of my data; there's no real concept of a line, but instead I have disparate data points in 2D space that I want to visualize. There's a function for that!
In [6]:
x = np.array([4, 5, 6])
y = np.array([9, 4, 7])
plt.scatter(x, y)
We use the plt.scatter() function, which operates pretty much the same way as plt.plot(), except it draws a dot for each data point without drawing lines between them.
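One practical difference: because scatter() treats each point individually, it accepts per-point marker sizes (s) and colors (c) as arrays, whereas plot() uses a single marker size for the whole line. A sketch with arbitrary sizes:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.array([4, 5, 6])
y = np.array([9, 4, 7])
sizes = np.array([20, 80, 200])  # Marker area (in points squared) for each dot.

points = plt.scatter(x, y, s=sizes)  # Returns a PathCollection holding the markers.
print(points.get_offsets())  # The (x, y) position of each dot.
```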
Another very useful plot, especially in scientific circles, is the errorbar plot. This is a lot like the line plot, except each data point comes with an errorbar to quantify uncertainty or variance present in each datum.
In [7]:
# This is a great function that gives me 50 evenly-spaced values from 0 to 10.
x = np.linspace(0, 10, 50)
dy = 0.8 # The size of each error bar (how uncertain each y-value is).
y = np.sin(x) + dy * np.random.random(50) # Adds a little bit of noise.
plt.errorbar(x, y, yerr = dy)
You use the yerr argument of plt.errorbar() to specify the size of your error in the y-direction. There's also an optional xerr argument, if your error is actually in the x-direction.
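Both yerr and xerr can also be arrays, giving each point its own error bar. A sketch with made-up per-point errors; fmt="o" draws markers without connecting lines:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 10)
y = np.sin(x)
y_errors = np.linspace(0.1, 0.5, 10)  # Made-up: uncertainty grows along the x-axis.
x_error = 0.2  # A single value applies the same x-error to every point.

bars = plt.errorbar(x, y, yerr=y_errors, xerr=x_error, fmt="o")
print(bars.has_xerr, bars.has_yerr)
```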
What about the histograms we built from the color channels of the images in last week's lectures? We can use matplotlib's hist() function for this.
In [8]:
x = np.random.normal(size = 100)
_ = plt.hist(x, bins = 20)
plt.hist() has only one required argument: a list of numbers. However, the optional bins argument is very useful, as it dictates how many bins the data are divided into. Too many bins and most bars in the histogram will hold only zero or one data point; too few bins and nearly all your data will be lumped into one or two bars!
Here's too few bins:
In [9]:
_ = plt.hist(x, bins = 2)
And too many:
In [10]:
_ = plt.hist(x, bins = 200)
Picking the number of bins for histograms is an art unto itself that usually requires a lot of trial-and-error, hence the importance of having a good visualization setup!
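Two practical notes on plt.hist(). First, it returns three things: the count in each bin, the bin edges, and the drawn bars; assigning the result to _ (as in the examples above) simply discards it so the notebook doesn't print it out. Second, bins also accepts a string naming one of NumPy's bin-selection heuristics, such as "auto", which gives a reasonable starting point for that trial and error (a starting guess, not a guarantee of a good-looking histogram). A sketch of both:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.random.normal(size=100)

# plt.hist() returns (counts per bin, bin edges, drawn bar patches).
counts, edges, bars = plt.hist(x, bins=20)
print(len(counts))   # 20: one count per bin
print(len(edges))    # 21: adjacent bins share an edge, so there's one extra
print(counts.sum())  # 100.0: every data point lands in exactly one bin

# Let NumPy's "auto" heuristic choose the number of bins instead.
plt.figure()
auto_counts, auto_edges, _ = plt.hist(x, bins="auto")
print(len(auto_counts), "bins chosen by 'auto'")

# The same edges can be computed without plotting anything.
same_edges = np.histogram_bin_edges(x, bins="auto")
```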
Another point on histograms, specifically their lone required argument: matplotlib expects a 1D array.
This is important if you're trying to visualize, say, the pixel intensities of an image channel. Images are either 2D (grayscale) or 3D (color, e.g. RGB), and even a single channel of a color image is 2D.
As such, if you feed an image channel directly into the hist method, matplotlib won't give you what you want:
In [11]:
import matplotlib.image as mpimg
img = mpimg.imread("Lecture22/image1.png") # Our good friend!
channel = img[:, :, 0] # The "R" channel
_ = plt.hist(channel)
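It's hard to tell at a glance, but what happened is that matplotlib treated each column of the 2D array as a separate dataset and drew a separate set of bars for each one, all crammed onto the same axes. It definitely is not the intensity histogram we were hoping for.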
Here's the magical way around it: all NumPy arrays (which image objects are!) have a flatten() method.
This method is dead simple: no matter how many dimensions the NumPy array has, whether it's a grayscale image (2D), a color image (3D), or a higher-dimensional tensor, it flattens the whole thing out into one long 1D array of numbers.
In [12]:
print(channel.shape) # Before
flat = channel.flatten()
print(flat.shape) # After
Then just feed the flattened array into the hist method:
In [13]:
_ = plt.hist(flat)
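If you don't have image1.png on hand, you can reproduce the whole episode with a small synthetic "image": random values standing in for real pixel data (an assumption for illustration only). The 2D channel yields one histogram per column; the flattened version yields the single intensity histogram we wanted:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
fake_img = rng.random((40, 30, 3))  # Hypothetical 40x30 RGB "image", values in [0, 1).
channel = fake_img[:, :, 0]         # The "R" channel: 2D, shape (40, 30).

# 2D input: each of the 30 columns is treated as its own dataset.
col_counts, _, _ = plt.hist(channel, bins=10)
print(len(col_counts))  # 30 separate histograms

plt.figure()
# Flattened input: one histogram over all 40 * 30 = 1200 pixel values.
flat = channel.flatten()
counts, _, _ = plt.hist(flat, bins=10)
print(flat.shape, counts.sum())
```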
The last type of plot we'll discuss here isn't really a "plot" in the same sense as the previous ones, but it is no less important: showing images!
In [14]:
img = mpimg.imread("Lecture22/image1.png")
plt.imshow(img)
The plt.imshow() method takes a matrix as input and renders it as an image. If the matrix is 3D, it treats it as an RGB image with shape (height, width, 3) and uses the third dimension to determine each pixel's color (a fourth channel, for RGBA, is also accepted). If the matrix is only 2D, each entry is treated as a single intensity value and mapped to a color through a colormap.
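A small sketch of the two cases, using made-up arrays. Note that NumPy image arrays are indexed (rows, columns, channels), i.e. (height, width, 3), and that float RGB values are expected to lie between 0 and 1:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rgb = np.zeros((4, 6, 3))  # 4 rows (height) x 6 columns (width) x 3 channels
rgb[:, :3, 0] = 1.0        # Left half pure red...
rgb[:, 3:, 2] = 1.0        # ...right half pure blue.
color_image = plt.imshow(rgb)

plt.figure()
values = np.arange(24).reshape(4, 6)  # 2D: one scalar per pixel
scalar_image = plt.imshow(values)     # Colored through the default colormap.
print(scalar_image.get_cmap().name)
```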
It doesn't even have to be a "true" image. Often you want to look at a matrix you're building, just to get a "feel" for its structure. imshow() is great for this as well.
In [15]:
matrix = np.random.random((100, 100))
plt.imshow(matrix, cmap = "gray")
We built a random matrix called matrix, and as you can see it looks exactly like that: in fact, a lot like TV static (coincidence?...). The optional cmap = "gray" argument specifies the "colormap". Matplotlib has quite a few; this one explicitly enforces grayscale, whereas without it matplotlib falls back on its default colormap.
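To see which colormaps your installation offers, plt.colormaps() returns their names. A sketch that renders the same random matrix with two of them; a colormap name ending in "_r" is the reversed version of that map:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

names = plt.colormaps()  # List of registered colormap names.
print(len(names), "colormaps available")

matrix = np.random.random((100, 100))
plt.imshow(matrix, cmap="gray")    # Low values black, high values white.

plt.figure()
plt.imshow(matrix, cmap="gray_r")  # Reversed: low values white, high values black.
```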
You may be thinking at this point: this is all cool, but my inner graphic designer cringed at how a few of these plots looked. Is there any way to make them look, well, "nicer"?
There are, in fact, a couple of things we can do to spiff things up a little, starting with how we can annotate the plots in various ways.
You can add text along the axes and the top of the plot to give a little extra information about what, exactly, your plot is visualizing. For this you use the plt.xlabel(), plt.ylabel(), and plt.title() functions.
In [16]:
x = np.linspace(0, 10, 50) # 50 evenly-spaced numbers from 0 to 10
y = np.sin(x) # Compute the sine of each of these numbers.
plt.plot(x, y)
plt.xlabel("x") # This goes on the x-axis.
plt.ylabel("sin(x)") # This goes on the y-axis.
plt.title("Plot of sin(x)") # This goes at the top, as the plot title.
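Labels and titles can also contain math: matplotlib renders text between dollar signs with its built-in TeX-like "mathtext" parser, with no LaTeX installation required. A sketch; raw strings (r"...") keep the backslashes intact:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
plt.plot(x, np.sin(x))
plt.xlabel(r"$x$")
plt.ylabel(r"$\sin(x)$")  # Rendered as a typeset sine.
plt.title(r"Plot of $\sin(x)$ on $[0, 10]$")
```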
Going back to the idea of plotting multiple datasets on a single figure, it'd be nice to label them in addition to using colors to distinguish them. Luckily, we have legends we can use, but it takes a coordinated effort to use them effectively. Pay close attention:
In [17]:
x = np.linspace(0, 10, 50) # Evenly-spaced numbers from 0 to 10
y1 = np.sin(x) # Compute the sine of each of these numbers.
y2 = np.cos(x) # Compute the cosine of each number.
plt.plot(x, y1, label = "sin(x)")
plt.plot(x, y2, label = "cos(x)")
plt.legend(loc = 0)
First, you'll notice that the plt.plot() calls changed a little with the inclusion of an optional argument: label. This string is the label that will show up in the legend.
Second, you'll also see a call to plt.legend(). This instructs matplotlib to show the legend on the plot. The loc argument specifies the location; 0 tells matplotlib to "put the legend in the best possible spot," meaning wherever it overlaps the plotted data the least. This is usually the best option, but if you want to override this behavior and specify a particular location, the numbers 1 through 10 (or equivalent strings such as "upper right" or "lower left") refer to specific areas of the plot.
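Instead of remembering the numeric codes, you can pass the equivalent location string. A sketch placing the legend in the upper right corner (code 1):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
plt.plot(x, np.sin(x), label="sin(x)")
plt.plot(x, np.cos(x), label="cos(x)")
legend = plt.legend(loc="upper right")  # Same as loc=1; loc="best" is the same as loc=0.

print([text.get_text() for text in legend.get_texts()])  # ['sin(x)', 'cos(x)']
```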
Setting explicit axis limits really comes in handy when you need to make multiple plots that span different datasets but which you want to compare directly. We've seen how matplotlib scales the axes so the data you're plotting are visible, but if you're plotting the data in entirely separate figures, matplotlib may scale the figures differently. To set explicit axis limits, use plt.xlim() and plt.ylim():
In [18]:
x = np.linspace(0, 10, 50) # Evenly-spaced numbers from 0 to 10
y = np.sin(x) # Compute the sine of each of these numbers.
plt.plot(x, y)
plt.xlim([-1, 11]) # Range from -1 to 11 on the x-axis.
plt.ylim([-3, 3]) # Range from -3 to 3 on the y-axis.
This can potentially help center your visualizations, too.
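For instance, to compare two datasets drawn in separate figures on an equal footing, give both the same limits. A sketch with two arbitrary curves; calling plt.ylim() with no arguments returns the current limits, which is handy for checking:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
shared_ylim = (-3, 3)  # Chosen to fit both curves below.

plt.figure()
plt.plot(x, np.sin(x))        # Amplitude 1
plt.ylim(shared_ylim)
first_limits = plt.ylim()

plt.figure()
plt.plot(x, 2.5 * np.cos(x))  # Amplitude 2.5
plt.ylim(shared_ylim)
second_limits = plt.ylim()

print(first_limits, second_limits)  # Both report the same limits.
```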
Matplotlib has a default progression of colors it uses in plots; you may have noticed the first data you plot is always blue, followed by orange (in matplotlib versions before 2.0, it was green). You're welcome to stick with this, or you can manually override the color scheme in any plot using the optional argument c (for color).
In [19]:
x = np.linspace(0, 10, 50) # Evenly-spaced numbers from 0 to 10
y = np.sin(x) # Compute the sine of each of these numbers.
plt.plot(x, y, c = "cyan")
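Beyond named colors like "cyan", c also accepts hex strings and the "CN" shorthand, where "C0", "C1", and so on refer to entries in the default color cycle (stored in plt.rcParams["axes.prop_cycle"]). A sketch; matplotlib.colors.to_hex() converts any of these forms to a hex string for comparison:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import numpy as np

# The default color cycle lives in rcParams.
default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
print(default_colors[:3])

x = np.linspace(0, 10, 50)
named, = plt.plot(x, np.sin(x), c="cyan")                    # A named color
hexed, = plt.plot(x, np.cos(x), c="#ff7f0e")                 # A hex string
cycled, = plt.plot(x, np.sin(x) * np.cos(x), c="C2")         # Third color in the cycle

print(to_hex(named.get_color()))  # '#00ffff'
```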
If you're making scatter plots, it can be especially useful to specify the type of marker in addition to the color you want to use. This can really help differentiate multiple scatter plots that are combined on one figure.
In [20]:
X1 = np.random.normal(loc = [-1, -1], size = (10, 2))
X2 = np.random.normal(loc = [1, 1], size = (10, 2))
plt.scatter(X1[:, 0], X1[:, 1], c = "black", marker = "v")
plt.scatter(X2[:, 0], X2[:, 1], c = "yellow", marker = "o")
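Matplotlib keeps its table of marker codes in matplotlib.markers.MarkerStyle.markers, a dictionary mapping each code to a descriptive name. A sketch that looks a few of them up, then labels the two point clouds from above so a legend can tell them apart (the cluster labels are arbitrary):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle
import numpy as np

for code in ["o", "v", "s", "^"]:
    print(repr(code), "->", MarkerStyle.markers[code])

X1 = np.random.normal(loc=[-1, -1], size=(10, 2))
X2 = np.random.normal(loc=[1, 1], size=(10, 2))
plt.scatter(X1[:, 0], X1[:, 1], c="black", marker="v", label="cluster 1")
plt.scatter(X2[:, 0], X2[:, 1], c="yellow", marker="o", label="cluster 2")
legend = plt.legend(loc="best")
```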
Finally, when you're rendering images, and especially matrices, it can help to have a colorbar that shows the scale of colors you have in your image plot.
In [21]:
matrix = np.random.normal(size = (100, 100))
plt.imshow(matrix, cmap = "gray")
plt.colorbar()
The matrix is clearly still random, but the colorbar tells us the range of values in the picture (roughly -3.5 to +4 in the run shown here; since the data are random, your exact range will differ a bit), giving us an idea of what's in our data.
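By default the color scale runs from the smallest to the largest value in the data, so two random matrices get slightly different scales. When comparing images side by side, you can pin the scale with imshow()'s vmin and vmax arguments. A sketch:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

matrix = np.random.normal(size=(100, 100))

image = plt.imshow(matrix, cmap="gray")
plt.colorbar()
print(image.get_clim())  # Defaults to (matrix.min(), matrix.max())

plt.figure()
pinned = plt.imshow(matrix, cmap="gray", vmin=-4, vmax=4)  # Fixed color scale
plt.colorbar()
print(pinned.get_clim())
```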
The truth is, there is endless freedom in matplotlib to customize the look and feel; you could spend a career digging through the documentation to master the ability to change edge colors, line thickness, and marker transparencies. At least in my opinion, there's a better way.
In [22]:
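import seaborn as sns # THIS IS THE KEY TO EVERYTHING
sns.set_theme() # Since seaborn 0.8, importing alone no longer restyles plots; this applies seaborn's look (use sns.set() before 0.11).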
x = np.linspace(0, 10, 50) # Evenly-spaced numbers from 0 to 10
y = np.sin(x) # Compute the sine of each of these numbers.
plt.plot(x, y)
The seaborn package is a plotting library in its own right, but first and foremost it effectively serves as a "light reskin" of matplotlib, changing the defaults (sometimes drastically) to be much more aesthetically and practically agreeable. Note that since seaborn 0.8, importing it no longer changes matplotlib's defaults on its own; you need to call sns.set_theme() (or sns.set() in versions before 0.11) to apply its look.
There will certainly be cases where seaborn doesn't solve your plotting issue, but for the most part, I've found that using seaborn assuages a lot of my complaints.
Matplotlib has a ton of other functionality we haven't touched on; for example, in case you want to look into it, there's the Axes3D object (from mpl_toolkits.mplot3d) you can use to create 3D lines, scatter plots, and surfaces.
The one thing it doesn't quite do is allow for interactive plots: figures you can embed within HTML that use JavaScript to let you do things like zoom in on certain regions or click on specific points. For that, you can turn to other plotting packages like bokeh (pronounced "bouquet") or mpld3.