Pre-MAP Course Website | Pre-MAP GitHub | Google

Plotting with matplotlib

Examples in this notebook are based on Nicholas Hunt-Walker's plotting tutorial and Jake VanderPlas' matplotlib tutorial.

In this notebook we will learn how to make basic plots, such as scatter plots, histograms, and line plots, in Python using matplotlib.

If you know what you want a plot to look like but not the code to make it, browse the matplotlib gallery, which shows example plots alongside the source code that generated them.

Basic Plot Commands

Some of the basic plotting commands include:

plt.plot()         # all purpose plotting function
plt.errorbar()     # plotting with errorbars
plt.loglog(), plt.semilogx(), plt.semilogy()   # plotting in logarithmic space

In [ ]:
# we use matplotlib and specifically pyplot
# the convention is to import it like this:
import matplotlib.pyplot as plt 

# We'll also use numpy for arrays and read some data using astropy, so let's import those
import numpy as np 
from astropy.io import ascii

In [ ]:
# I'm also using this "magic" function to make my plots appear in this notebook
# Only do this when working with notebooks
%matplotlib inline

Let's make some sample x and y data, and plot it with the plt.plot command:


In [ ]:
# Sample data
x = np.arange(10)
y = np.arange(10, 20)

# Make the plot, then show the plot
plt.plot(x, y)
plt.show()

You can customize many features of the plot:

  • markersize sets the symbol size
  • color sets the color
  • The third positional argument, a short format string, sets the marker or line style. Try markers like 'x', '.', 'o', '+' or line styles like '--', '-.', ':'

In [ ]:
plt.plot(x, y, '.', markersize=20, color='red')  
plt.show()
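The format string is shorthand: it can combine a color, a marker, and a line style in one code. As a sketch (using the same sample data), 'ro--' should produce the same style as writing the options out as keyword arguments; the second line is shifted up by 2 only so both are visible:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(10)
y = np.arange(10, 20)

# Shorthand: 'r' = red, 'o' = circle markers, '--' = dashed line
line_short, = plt.plot(x, y, 'ro--')

# The same style spelled out with keyword arguments
line_long, = plt.plot(x, y + 2, color='red', marker='o', linestyle='--')

plt.show()
```

The keyword form is longer but easier to read, and it is the only option for settings that have no format-string code, like markersize.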

Let's plot a sine function:


In [ ]:
# Let's make x go from zero to 4*pi
x = np.linspace(0, 4*np.pi, 50)
y = np.sin(x)

# This will be a thick dashed line:
plt.plot(x, y, linestyle='--', linewidth=5) 

# Add labels to the axes
plt.xlabel('Xlabel') 
plt.ylabel('Ylabel')

# Set the plot title
plt.title('Sine Curve')
plt.show()

Let's plot a figure with errorbars:


In [ ]:
# Let's plot y=x^3
x = np.arange(10)
y = x**3

# Let's make up some errorbars in x and y
xerr_values = 0.2 * np.sqrt(x)
yerr_values = 5 * np.sqrt(y)

# Call the errorbar function 
plt.errorbar(x, y, xerr=xerr_values, yerr=yerr_values)
plt.show()
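By default plt.errorbar() joins the points with a line, which is rarely what you want for measurements. A sketch of a few optional arguments: fmt='o' draws markers only, capsize adds caps to the bars, and passing two arrays to yerr gives asymmetric errors (lower, then upper). The error sizes are made up, as in the cell above:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(10)
y = x**3

# Asymmetric errors: first array is the lower error, second is the upper error
yerr_low = 2 * np.sqrt(y)
yerr_high = 5 * np.sqrt(y)

container = plt.errorbar(x, y, yerr=[yerr_low, yerr_high],
                         fmt='o',        # markers only, no connecting line
                         capsize=3,      # small caps at the ends of the bars
                         color='black')
plt.show()
```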

There are three options for log plots: plt.loglog() (both axes logarithmic), plt.semilogx() (logarithmic x-axis only), and plt.semilogy() (logarithmic y-axis only).


In [ ]:
x = np.linspace(0, 20)
y = np.exp(x)

plt.semilogy(x, y) 
plt.show()
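A handy rule of thumb: an exponential like the one above becomes a straight line with plt.semilogy(), while a power law such as y = x**2 becomes a straight line with plt.loglog(). A small sketch drawing the same power law with all three functions side by side:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(1, 100)
y = x**2  # a power law

plt.figure(figsize=(12, 3))

plt.subplot(1, 3, 1)
plt.loglog(x, y)        # both axes logarithmic: a straight line
plt.title('loglog')
ax_loglog = plt.gca()

plt.subplot(1, 3, 2)
plt.semilogx(x, y)      # only the x-axis is logarithmic
plt.title('semilogx')
ax_semilogx = plt.gca()

plt.subplot(1, 3, 3)
plt.semilogy(x, y)      # only the y-axis is logarithmic
plt.title('semilogy')
ax_semilogy = plt.gca()

plt.show()
```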

To add a legend to your plot, include the label argument in the plot command, then call plt.legend() at the end of the plotting commands, before plt.show().


In [ ]:
xred = np.random.rand(100)
yred = np.random.rand(100)

xblue = np.random.rand(20)
yblue = np.random.rand(20)

plt.plot(xred, yred, '^', color='red', markersize=8, 
         label='Red Points')

plt.plot(xblue, yblue, '+', color='blue', markersize=12, 
         markeredgewidth=3, label='Blue Points')

plt.xlabel('Xaxis')
plt.ylabel('Yaxis')

plt.legend()

# You can also place the legend in different places using this: 
# plt.legend(loc='lower left')
plt.show()
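Only artists that were given a label show up in the legend. One way to check what plt.legend() will pick up is plt.gca().get_legend_handles_labels(); in this sketch the dashed gray line has no label, so it should be left out:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10)

plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.plot(x, np.zeros_like(x), '--', color='gray')  # no label, so no legend entry

# The labelled lines that plt.legend() will use
handles, labels = plt.gca().get_legend_handles_labels()
print(labels)

plt.legend(loc='upper right')
plt.show()
```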

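To save a figure, call plt.savefig() with a file name. Call it before plt.show(): once a figure has been shown it may be closed, and saving afterwards can give you an empty image.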


In [ ]:
x = np.linspace(0, 10)
y = np.sin(x)

plt.plot(x, y)
plt.title('sin')
plt.xlabel('Xaxis')
plt.ylabel('Yaxis')

# just give savefig the file name, or path to file name that you want to write
plt.savefig('sineplot.png') 
plt.show()
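plt.savefig() picks the file format from the extension, and two optional keywords are often useful: dpi sets the resolution of raster formats like PNG, and bbox_inches='tight' trims extra white space around the plot. A sketch that writes a PNG and a PDF into a temporary directory (the file names are only examples):

```python
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10)
plt.plot(x, np.sin(x))
plt.xlabel('Xaxis')
plt.ylabel('Yaxis')

outdir = tempfile.mkdtemp()
png_path = os.path.join(outdir, 'sineplot.png')
pdf_path = os.path.join(outdir, 'sineplot.pdf')

# Save before plt.show(), while the figure still exists
plt.savefig(png_path, dpi=200, bbox_inches='tight')  # raster PNG at 200 dpi
plt.savefig(pdf_path, bbox_inches='tight')           # vector PDF
plt.show()
```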

Let's make a log-log plot displaying Kepler's third law:


In [ ]:
# Semimajor-axis in units of AU:
a_AU = np.array([0.387, 0.723, 1. , 1.524, 5.203, 9.537, 19.191, 30.069, 39.482])

# Orbital period in units of years
T_yr = np.array([0.24, 0.62, 1., 1.88, 11.86, 29.46, 84.01, 164.8, 247.7])

# Let's set the gravitational constant and the mass of the Sun in cgs units:
G = 6.67e-8
Msun = 1.99e+33 

plt.loglog(a_AU, T_yr, 'o')
plt.xlabel('Semi-Major Axis [AU]')
plt.ylabel('Period [yrs]')
plt.show()

In [ ]:
# now plot a function over the data
# as you work more in python you will learn how to actually fit models to your data 
def keplers_third_law(a, M):
    # orbital period T = sqrt(4 pi^2 a^3 / (G M)), with a in cm and M in g; returns seconds
    return np.sqrt((4*np.pi**2 * a**3) / (G * M))

# Convert semimajor-axis into centimeters
a_cm = a_AU * 1.496e+13 

# Convert period into seconds
T_s = T_yr * 3.154e+7

plt.loglog(a_cm, T_s, 'o')
plt.loglog(a_cm, keplers_third_law(a_cm, Msun), '--', 
           label="Kepler's Third Law") # try swapping out Msun with something else and see what it looks like 
plt.xlabel('Semi-Major Axis [cm]')
plt.ylabel('Period [s]')
plt.legend(loc=2) # loc=2 is the same as loc='upper left'
plt.show()
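As a first taste of fitting a model, and only a sketch, np.polyfit can fit a straight line to log10(T) versus log10(a). Kepler's third law (T² ∝ a³) predicts a slope of 1.5, and in units of AU and years the intercept should be close to 0. This reuses the planetary values from above:

```python
import numpy as np
import matplotlib.pyplot as plt

# Same data as above: semi-major axis in AU, period in years
a_AU = np.array([0.387, 0.723, 1., 1.524, 5.203, 9.537, 19.191, 30.069, 39.482])
T_yr = np.array([0.24, 0.62, 1., 1.88, 11.86, 29.46, 84.01, 164.8, 247.7])

# A power law T = C * a**n is a straight line in log space:
# log10(T) = n * log10(a) + log10(C)
slope, intercept = np.polyfit(np.log10(a_AU), np.log10(T_yr), 1)
print(f"slope = {slope:.3f}, intercept = {intercept:.3f}")

# Draw the best-fit power law over the data
a_grid = np.logspace(np.log10(a_AU.min()), np.log10(a_AU.max()))
plt.loglog(a_AU, T_yr, 'o', label='Planets')
plt.loglog(a_grid, 10**intercept * a_grid**slope, '--', label='Best-fit power law')
plt.xlabel('Semi-Major Axis [AU]')
plt.ylabel('Period [yrs]')
plt.legend(loc='upper left')
plt.show()
```

Fitting in log space is the simplest approach for a power law; it treats the errors as multiplicative, which is a reasonable assumption here but not in general.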

Scatter Plots


In [ ]:
# first let's read in some data to use for plotting
galaxy_table = ascii.read('data/mygalaxy.dat')
galaxy_table[:5]

In [ ]:
# simple scatter plot
plt.scatter(galaxy_table['col1'], galaxy_table['col2'])
plt.show()

SIDE NOTE: If you are working in an IPython terminal session or running a script, call plt.show() after your plotting commands to make the plot appear in a new window:

plt.scatter(galaxy_table['col1'], galaxy_table['col2'])
plt.show()

In an IPython Notebook, you will see the plot outputs whether or not you call plt.show() because we've used the %matplotlib inline magic function.

Let's break down these basic examples:

  • We are running functions called "plot" or "scatter" that take specific arguments.
  • The most basic arguments that these functions take are in the form of (x,y) values for the plot, and we get these from a data table.
  • We can use more specific arguments, like 'o', to customize things such as the plot symbol (marker) that we are using.

With plt.scatter() you can change things like the point color, point size, point edge color, and marker shape. The argument syntax for these options is as follows:

  • color = 'colorname'; e.g. 'b' for blue, 'k' for black, 'r' for red
  • s = number; the marker size (an area, in points squared)
  • edgecolor = 'none' or 'colorname'; the color of the marker edges
  • marker = 'symbolname', e.g. 's' for square, 'o' for circle, '+' for cross, 'x' for x, '*' for star, '^' for triangle, etc.

Let's do an example:


In [ ]:
plt.scatter(galaxy_table['col1'], galaxy_table['col2'], 
            color='blue', s=1, edgecolor='None', marker='o')
plt.show()

In [ ]:
# here would be the equivalent statement using plt.plot(), note that the syntax is a little different 
plt.plot(galaxy_table['col1'], galaxy_table['col2'], 'o', 
         color='blue', markersize=1, markeredgecolor='None')
plt.show()
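One difference between the two calls: s in plt.scatter() is the marker area in points squared, while markersize in plt.plot() is the marker size in points. So s=1 matches markersize=1, but s=25 roughly matches markersize=5. A sketch with random points standing in for the galaxy table:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x = rng.uniform(-180, 180, 500)  # random stand-in for longitude
y = rng.uniform(-90, 90, 500)    # random stand-in for latitude

markersize = 5

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
# scatter: s is the marker area in points**2
points = plt.scatter(x, y, s=markersize**2, color='blue', edgecolor='none', marker='o')
plt.title('plt.scatter, s=25')

plt.subplot(1, 2, 2)
# plot: markersize is the marker size in points
line, = plt.plot(x, y, 'o', color='blue', markersize=markersize, markeredgecolor='none')
plt.title('plt.plot, markersize=5')

plt.show()
```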

The plot is starting to look better, but one really important thing is missing: axis labels. These are easy to add in matplotlib using plt.xlabel() and plt.ylabel(). These functions take the label text as a string, and can also take other arguments that change the text formatting, such as fontweight and size:


In [ ]:
plt.scatter(galaxy_table['col1'], galaxy_table['col2'], color='blue', 
            s=1, edgecolor='None', marker='o')

plt.xlabel('Galactic Longitude (degrees)', 
           fontweight='bold', size=16)
plt.ylabel('Galactic Latitude (degrees)', 
           fontweight='bold', size=16)
plt.show()

We can also change the axis limits with plt.xlim() and plt.ylim(). Each takes a [lower, upper] pair of values for its axis:


In [ ]:
plt.scatter(galaxy_table['col1'], galaxy_table['col2'], 
            color='blue', s=1, edgecolor='None', marker='o')

plt.xlabel('Galactic Longitude (degrees)', 
           fontweight='bold', size=16)
plt.ylabel('Galactic Latitude (degrees)', 
           fontweight='bold', size=16)

plt.xlim([-180,180])
plt.ylim([-90,90])
plt.show()

The axis labels are easy to read, but the numbers and tick marks on the axis are pretty small. We can tweak lots of little things about how the tick marks look, how they are spaced, and if we want to have a grid to guide the reader's eyes. I will give just a couple of examples here:


In [ ]:
plt.scatter(galaxy_table['col1'], galaxy_table['col2'], 
            color='blue', s=1, edgecolor='None', marker='o')

# Labels
plt.xlabel('Galactic Longitude (degrees)', 
           fontweight='bold', size=16)
plt.ylabel('Galactic Latitude (degrees)', 
           fontweight='bold', size=16)

# Set limits
plt.xlim([-180,180])
plt.ylim([-90,90])

# Choose axis ticks
plt.xticks(range(-180,210,60), fontsize=16, fontweight='bold') # change tick spacing, font size and bold
plt.yticks(range(-90,120,30), fontsize=16, fontweight='bold')

# turn on minor tick marks 
plt.minorticks_on()

plt.grid() # turn on a background grid to guide the eye 
plt.show()

The default figure shape is not necessarily the best way to represent our data. This map spans 360 degrees in longitude but only 180 in latitude, so a wider figure works better. plt.figure(figsize=(width, height)) sets the figure size in inches:


In [ ]:
plt.figure(figsize=(10,4)) # change figure size 
plt.scatter(galaxy_table['col1'], galaxy_table['col2'], 
            color='blue', s=1, edgecolor='None', marker='o')

# Labels
plt.xlabel('Galactic Longitude (degrees)', 
           fontweight='bold', size=16)
plt.ylabel('Galactic Latitude (degrees)', 
           fontweight='bold', size=16)

# Set limits
plt.xlim([-180,180])
plt.ylim([-90,90])

# Choose axis ticks
plt.xticks(range(-180,210,60), fontsize=16, fontweight='bold') # change tick spacing, font size and bold
plt.yticks(range(-90,120,30), fontsize=16, fontweight='bold')

# turn on minor tick marks 
plt.minorticks_on()

plt.grid() # turn on a background grid to guide the eye 
plt.show()

The last thing I'll mention here is how to put text on plots. This is simple too: plt.text() takes the (x, y) position of the text, in data coordinates, followed by the string to display.


In [ ]:
plt.figure(figsize=(10,4)) # change figure size 
plt.scatter(galaxy_table['col1'], galaxy_table['col2'], 
            color='blue', s=1, edgecolor='None', marker='o')

# the next three lines put text on the figure at the specified coordinates
plt.text(-90, -50, 'LMC', fontsize=20) 
plt.text(-60, -60, 'SMC', fontsize=20)
plt.text(0, -30, 'MW Bulge', fontsize=20)

plt.xlabel('Galactic Longitude (degrees)', 
           fontweight='bold', size=16)
plt.ylabel('Galactic Latitude (degrees)', 
           fontweight='bold', size=16)

plt.xlim([-180,180])
plt.ylim([-90,90])

plt.xticks(range(-180,210,60), fontsize=16, fontweight='bold') # change tick spacing, font size and bold
plt.yticks(range(-90,120,30), fontsize=16, fontweight='bold')
plt.minorticks_on() # turn on minor tick marks 
plt.grid() # turn on a background grid to guide the eye 
plt.show()
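plt.text() only places a label. plt.annotate(), which is used later in this notebook for the exoplanet plot, can also draw an arrow from the label to a point of interest. A sketch where the coordinates are example positions in data units (the galactic center sits at longitude 0, latitude 0):

```python
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.xlim([-180, 180])
plt.ylim([-90, 90])

# xy is the point being labelled; xytext is where the text goes (both in data coordinates)
ann = plt.annotate('MW Bulge', xy=(0, 0), xytext=(60, 50), fontsize=16,
                   arrowprops=dict(arrowstyle='->', color='black'))
plt.show()
```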

Histograms

Histograms can be a great way to visualize data, and they are (surprise) easy to make in Python! The basic command is

num, bins, patches = plt.hist(array, bins=number)

num holds the number of elements in each bin, and bins holds the bin EDGES along the x-axis, so bins always has one more element than num (len(bins) == len(num) + 1). We can ignore patches for now. As arguments, plt.hist() takes an array of data and the number of bins you would like (default is bins=10). Some other optional arguments for plt.hist() are:

  • range: the (lower, upper) range of the bins; values outside it are ignored.
  • density: set to True or False. If True, the y-axis shows a normalized probability density (the total area under the histogram is 1) instead of raw number counts. Older matplotlib versions called this argument normed, which has been removed in recent versions.
  • histtype: the histogram style, e.g. 'bar' (the default), 'step', or 'stepfilled'.
  • weights: an array of values with the same shape as your data. Each data point adds its weight to its bin instead of adding 1, so the counts become sums of weights.

In [ ]:
# plots histogram where the y-axis is counts 
x = np.random.randn(10000)
num, bins, patches = plt.hist(x, bins=50)
plt.xlabel('Bins')
plt.ylabel('Counts')
plt.show()
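A small sketch combining some of those options (range and histtype) on random data. It also checks the point about bin edges: bins has one more entry than num.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
data = rng.normal(size=1000)

# 20 bins spanning -3 to 3, drawn as an unfilled outline
num, bins, patches = plt.hist(data, bins=20, range=(-3, 3), histtype='step')
plt.xlabel('Value')
plt.ylabel('Counts')
plt.show()

print(len(num), len(bins))  # 20 counts, 21 bin edges
```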

In [ ]:
# plots histogram where the y-axis is a probability density (total area under the histogram = 1)
plt.hist(x, bins=50, density=True)
plt.xlabel('Bins')
plt.ylabel('Probability Density')
plt.show()

In [ ]:
# plots a histogram where the y-axis is a fraction of the total 
weights = np.ones_like(x)/len(x)
plt.hist(x, bins=50, weights=weights)
plt.ylabel('Fraction')
plt.xlabel('Bins')
plt.show()
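The two normalizations answer different questions: with density=True the area under the histogram is 1, while the weights trick makes the bar heights add up to 1. A quick numerical check of both, as a sketch using np.histogram, which accepts the same bins, density, and weights arguments as plt.hist but only computes the numbers:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=10000)

# density=True: height times bin width, summed over bins, is 1
density, edges = np.histogram(x, bins=50, density=True)
area = np.sum(density * np.diff(edges))

# weights of 1/N: the heights themselves sum to 1
fraction, _ = np.histogram(x, bins=50, weights=np.ones_like(x) / len(x))
total = fraction.sum()

print(area, total)  # both should be very close to 1
```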

In [ ]:
# print out num and bins and see what they look like! what size is each array?
# how would you plot this histogram using plt.plot? what is the x value and what is the y value?

Subplots

Subplots are a way to put multiple plots in the same figure; think of subplots as an array of plots! The following picture is helpful for understanding how matplotlib places subplots based on the number of rows, the number of columns, and the subplot index (counted left to right, then top to bottom):


In [ ]:
# make two side by side plots 

x1 = np.linspace(0.0, 5.0)
x2 = np.linspace(0.0, 2.0)

y1 = np.cos(2 * np.pi * x1) * np.exp(-x1)
y2 = np.cos(2 * np.pi * x2)

plt.figure(figsize=[15,3])
plt.subplot(1,2,1) # 1 row, 2 columns, 1st subplot 
plt.plot(x1,y1)
plt.xlabel('Xlabel')
plt.ylabel('Ylabel')

plt.subplot(1,2,2) # 1 row, 2 columns, 2nd subplot 
plt.plot(x2,y2)
plt.xlabel('Xlabel')
plt.ylabel('Ylabel')
plt.show()

In [ ]:
# stack two plots on top of one another

plt.subplot(2,1,1) # 2 rows, 1 column, 1st subplot 
plt.plot(x1,y1)
plt.xlabel('Xlabel')
plt.ylabel('Ylabel')

plt.subplot(2,1,2) # 2 rows, 1 column, 2nd subplot 
plt.plot(x2,y2)
plt.xlabel('Xlabel')
plt.ylabel('Ylabel')

plt.tight_layout() # adjust spacing so the labels of the two panels don't overlap
plt.show()

You can do fancier things with subplots, like having different plots share the same axis or putting smaller plots as insets in larger ones. Again, take a look at the matplotlib gallery for examples of different plots.
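As one example of sharing axes, plt.subplots() creates a figure and a grid of axes in a single call, and sharex=True makes stacked panels use the same x-axis (zooming one zooms the other). A sketch reusing the damped cosine from above:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0.0, 5.0, 200)

# 2 rows, 1 column; both panels share one x-axis
fig, axes = plt.subplots(2, 1, sharex=True, figsize=(8, 5))

axes[0].plot(x, np.cos(2 * np.pi * x) * np.exp(-x))
axes[0].set_ylabel('Damped')

axes[1].plot(x, np.cos(2 * np.pi * x))
axes[1].set_ylabel('Undamped')
axes[1].set_xlabel('Xlabel')  # only the bottom panel needs an x label

plt.tight_layout()
plt.show()
```

This object-oriented style (calling methods like set_ylabel on each axes object) does the same job as plt.subplot() plus plt.ylabel(), but makes it explicit which panel each command applies to.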

Plotting Exoplanets

Let's try to make some plots with a new dataset. The file that we'll use is taken from exoplanets.eu.


In [ ]:
# don't worry about this way to read in files right now 
import pandas as pd 
exoplanets = pd.read_csv('data/exoplanet.eu_catalog_1022.csv')

In [ ]:
# drop rows where the orbital period is missing (NaN), to be safe
exoplanets = exoplanets[np.isfinite(exoplanets['orbital_period'])]

In [ ]:
# let's see what the data table looks like
exoplanets.head()

In [ ]:
# plot distance from host star versus mass (in jupiter masses) for each exoplanet 
plt.loglog(exoplanets['semi_major_axis'], exoplanets['mass'],'.')
plt.annotate("Earth", xy=(1,1/317.), size=12)
plt.annotate("Jupiter", xy=(5,1), size=12)
plt.xlabel('Semi-Major Axis [AU]',size=20)
plt.ylabel('Mass [M$_{Jup}$]', size=20)

In [ ]:
# let's try to find out if the blobs above separate out by detection type
import seaborn as sns; sns.set()
transits = exoplanets[exoplanets['detection_type'] == 'Primary Transit']
radial_vel = exoplanets[exoplanets['detection_type'] == 'Radial Velocity']
imaging = exoplanets[exoplanets['detection_type'] == 'Imaging']
ttv = exoplanets[exoplanets['detection_type'] == 'TTV']
plt.loglog(transits['semi_major_axis'], transits['mass'], '.', label='Transit',markersize=12)
plt.loglog(radial_vel['semi_major_axis'], radial_vel['mass'], '.', label='Radial Vel', markersize=12)
plt.loglog(imaging['semi_major_axis'], imaging['mass'], '.', label='Direct Imaging', markersize=16)
plt.loglog(ttv['semi_major_axis'], ttv['mass'], '.', label='TTV', color='cyan', markersize=16)
plt.annotate("Earth", xy=(1,1/317.), size=12)
plt.annotate("Jupiter", xy=(5,1), size=12)
plt.xlabel('Semi-Major Axis [AU]', size=20)
plt.ylabel('Mass [M$_{Jup}$]', size=20)
plt.legend(loc=4, prop={'size':16})
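Writing one filter per detection method gets repetitive; a loop over the method names is a more compact way to express the same idea. To keep this sketch self-contained it uses a tiny hypothetical table with the same column names (semi_major_axis, mass, detection_type) in place of the real catalog:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the exoplanets table (NOT real catalog values)
planets = pd.DataFrame({
    'semi_major_axis': [0.05, 0.1, 1.2, 3.0, 40.0, 0.2],
    'mass': [0.5, 0.02, 1.5, 3.0, 8.0, 0.01],
    'detection_type': ['Primary Transit', 'Primary Transit', 'Radial Velocity',
                       'Radial Velocity', 'Imaging', 'TTV'],
})

methods = ['Primary Transit', 'Radial Velocity', 'Imaging', 'TTV']
for method in methods:
    # one filter and one plot call per detection method
    subset = planets[planets['detection_type'] == method]
    plt.loglog(subset['semi_major_axis'], subset['mass'], '.',
               markersize=12, label=method)

plt.xlabel('Semi-Major Axis [AU]')
plt.ylabel('Mass [M$_{Jup}$]')
ax = plt.gca()
plt.legend(loc='lower right')
plt.show()
```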

In [ ]:
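# and now just for fun an xkcd-style plot! 
# using plt.xkcd() as a context manager keeps the style from carrying over into later plots
with plt.xkcd():
    plt.scatter(exoplanets['discovered'], exoplanets['radius']*11)  # radius is in Jupiter radii; x11 gives roughly Earth radii
    plt.xlabel('Year Discovered')
    plt.ylabel('Radius [R_Earth]')
    plt.show()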

Hertzsprung-Russell Diagram

A Hertzsprung-Russell (HR) diagram is a type of scatter plot that is widely used by astronomers. On the x-axis is the color (B-V value), or effective temperature, and on the y-axis is the absolute magnitude, or luminosity. In other words, it is a plot of a star's brightness (y-axis) against its temperature (x-axis). As you can see in the plot below, temperature increases to the left, and brightness increases as you move up the y-axis. Note that this is an idealized depiction of an HR diagram. Observations will always show more scatter!

It is also apparent that when plotted in this way, stars fall into distinct groups. The most prominent group is the line extending diagonally across the plot. This is referred to as the "main sequence," and is the stage of stellar evolution where stars spend most of their time. During this stage stars are fusing hydrogen into helium in their cores. When stars eventually move off the main sequence they undergo different pathways of evolution depending on their masses. Some stars will fuse all the way up to iron in their cores, while some, like our sun, will only get to fusing helium.
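Because of these conventions, an HR diagram drawn against effective temperature needs both axes flipped: temperature increases to the left, and since smaller magnitudes mean brighter stars, the magnitude axis runs downward. (With B-V color on the x-axis no flip is needed there, because hotter stars already have smaller B-V.) A sketch of just the axis handling, using random placeholder points since no star catalog is loaded here:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(3)
# Random placeholder values, NOT real stars
temperature = rng.uniform(3000, 30000, 200)  # effective temperature in K
abs_mag = rng.uniform(-5, 15, 200)           # absolute magnitude

plt.scatter(temperature, abs_mag, s=5, color='black')
plt.xlabel('Effective Temperature [K]')
plt.ylabel('Absolute Magnitude')

ax = plt.gca()
ax.invert_xaxis()  # hotter stars on the left
ax.invert_yaxis()  # brighter stars (smaller magnitudes) at the top
plt.show()
```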


In [ ]: