PYT-DS: Subplots in Matplotlib

The VanderPlas syllabus is one of the most useful resources for this course, and in many ways core to it.

Jake VanderPlas has been a key player in promoting open source. He's an astronomer by training. Let's remember that the Space Telescope Science Institute (STScI), home base for the Hubble Space Telescope's management and data wrangling, has been a big investor in Python technologies, matplotlib included.

It was my great privilege to deliver a Python training at STScI on behalf of Holdenweb, a Steve Holden company.


In [ ]:
%matplotlib inline
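import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8-white') # named 'seaborn-white' before matplotlib 3.6
# plt.style.use('classic') # try uncommenting this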
import numpy as np

The first distinction to make is between the Figure, which is the outer frame or canvas, and the rectangular XY grids, or coordinate systems, we place within the figure. These XY grid objects are known as "axes" (matplotlib uses the word even for a single one), and most of the attributes we associate with plots are actually attached to axes.
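To make the Figure/Axes split concrete, here is a minimal sketch (the figure size and rectangle are arbitrary choices) in which each object is asked about the other: the Figure keeps a list of its Axes in fig.axes, and each Axes keeps a back-reference to its Figure in ax.figure.

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(4, 3))           # the outer frame / canvas
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])    # one XY coordinate system inside it

print(len(fig.axes))     # the Figure tracks every Axes placed on it
print(ax.figure is fig)  # each Axes points back to its Figure

# Plot attributes such as the title belong to the Axes, not the Figure
ax.set_title("Titles, limits and ticks live on the Axes")
print(ax.get_title())
```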

What may confuse the new user is that plt (pyplot) keeps track of which axes we're currently working on, so we can talk to axes through plt. That convenience obscures the direct connection twixt axes and their attributes.
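One way to watch pyplot's bookkeeping at work: in this sketch plt.gca() returns whichever Axes was created or selected last, a pyplot-level call such as plt.xlim lands on that Axes, and plt.sca switches which Axes is "current". The two-row layout is just for illustration.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
ax_top = plt.subplot(2, 1, 1)
ax_bottom = plt.subplot(2, 1, 2)

# pyplot now routes commands to the most recently created Axes
print(plt.gca() is ax_bottom)

plt.xlim(0, 5)               # a pyplot-level call...
print(ax_bottom.get_xlim())  # ...has changed ax_bottom's limits

plt.sca(ax_top)              # make the top Axes "current" again
plt.title("top")             # so this title goes on ax_top
print(ax_top.get_title())
```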

Below we avoid using plt completely except for initializing our figure, and manage to get two sets of axes (two plots inside the same figure).

In [ ]:
fig = plt.figure("main", figsize=(5,5)) # name or int id optional, as is figsize
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4],
                   xticklabels=[], ylim=(-1.2, 1.2)) # no x axis tick labels
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4],
                   ylim=(-1.2, 1.2))

x = np.linspace(0, 10)
_ = ax1.plot(np.sin(x)) # assign to dummy variable to suppress text output
_ = ax2.plot(np.cos(x))
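The four numbers given to fig.add_axes are [left, bottom, width, height], as fractions of the figure's width and height. So ax1 spans 0.5 to 0.9 vertically and sits directly on top of ax2, which spans 0.1 to 0.5. This sketch reuses those rectangles and reads the positions back with get_position(), which returns a bounding box in the same figure-fraction units:

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(5, 5))
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4])  # [left, bottom, width, height]
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4])

top_box = ax1.get_position()     # Bbox in figure-fraction coordinates
bottom_box = ax2.get_position()
print(top_box.y0, top_box.y1)        # about 0.5 and 0.9 (bottom, bottom + height)
print(bottom_box.y0, bottom_box.y1)  # about 0.1 and 0.5

# The bottom edge of ax1 coincides with the top edge of ax2
print(abs(top_box.y0 - bottom_box.y1) < 1e-9)
```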


Here's plt.subplot in action, creating axes automatically from the number of rows and columns we specify, followed by a sequence number i running from 1 through however many cells there are (in this case six).

Notice how plt is keeping track of which subplot axes are current, and we talk to said axes through plt.

In [ ]:
for i in range(1, 7):
    plt.subplot(2, 3, i)
    plt.text(0.5, 0.5, str((2, 3, i)),
             fontsize=18, ha='center')
    plt.xticks([]) # get rid of tickmarks on x axis
    plt.yticks([]) # get rid of tickmarks on y axis
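The sequence number in plt.subplot(2, 3, i) counts across each row first, left to right, then moves down to the next row. A sketch that asks each Axes where it landed, using the SubplotSpec that matplotlib attaches to grid-based Axes (its rowspan and colspan are ranges of 0-based row and column indices):

```python
import matplotlib.pyplot as plt

nrows, ncols = 2, 3
fig = plt.figure()
positions = {}
for i in range(1, nrows * ncols + 1):
    ax = plt.subplot(nrows, ncols, i)
    spec = ax.get_subplotspec()  # where this Axes sits in the grid
    positions[i] = (spec.rowspan.start, spec.colspan.start)

# Row-major order: i = 1..3 fill row 0, i = 4..6 fill row 1
for i, (row, col) in positions.items():
    assert (row, col) == divmod(i - 1, ncols)
print(positions)
```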

Here we're talking to the axes objects more directly by calling "get current axes". Somewhat confusingly, the instances returned have an "axes" attribute which points to the same instance, a wrinkle I explore below. Note the slight difference between the two set_visible lines.

In [ ]:
for i in range(1, 7):
    plt.subplot(2, 3, i)
    plt.text(0.5, 0.5, str((2, 3, i)),
             fontsize=18, ha='center')
    # synonymous.  gca means 'get current axes'
    plt.gca().axes.get_xaxis().set_visible(False) # axes optional, self referential
    plt.gca().get_yaxis().set_visible(False)
plt.gcf().subplots_adjust(hspace=0.1, wspace=0.1) # get current figure, adjust spacing
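The self-reference has a simple source: every matplotlib Artist carries an .axes attribute naming the Axes it is drawn on, and an Axes is itself an Artist, so its .axes is itself. A small sketch checking that (plt.subplots is used only for brevity):

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
print(ax.axes is ax)    # an Axes is its own .axes

line, = ax.plot([0, 1], [0, 1])
print(line.axes is ax)  # artists drawn on it point back to it

# Hence the two spellings used in the cell above reach the same XAxis object
print(plt.gca().axes.get_xaxis() is plt.gca().get_xaxis())
```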


You might need to install Pillow to get the code cells below to work. Pillow is a maintained fork of PIL, the Python Imaging Library, that supports Python 3; it is still imported under the name PIL.

Running conda install pillow in whichever Anaconda environment you're using, from the channel most compatible with it, is one way to get it. pip install pillow is another. The face.png binary is in your course folder for this evening.

Question: Might we use axes to show images?

Answer: Absolutely, as matplotlib axes have an imshow method.

In [ ]:
from PIL import Image # Image is a module!
plt.subplot(1, 2, 1)
plt.xticks([]) # get rid of tickmarks on x axis
plt.yticks([]) # get rid of tickmarks on y axis

im = Image.open("face.png") # face.png is in the course folder
plt.imshow(im)              # imshow accepts PIL images as well as arrays

plt.subplot(1, 2, 2)
plt.xticks([]) # get rid of tickmarks on x axis
plt.yticks([]) # get rid of tickmarks on y axis

# rotate 180 degrees
rotated = im.transpose(Image.ROTATE_180)
plt.imshow(rotated)
_ = plt.gcf().tight_layout()
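If face.png isn't handy, the same side-by-side layout can be tried on a NumPy array, since imshow accepts arrays as well as PIL images. In this sketch the gradient image is a made-up stand-in for the face, and np.rot90(img, 2) is swapped in for Pillow's Image.ROTATE_180 transpose; both turn the picture 180 degrees. It also uses the object-oriented style (ax.imshow, ax.set_xticks) rather than plt calls:

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical stand-in for face.png: a small RGB gradient (rows x cols x 3)
img = np.zeros((40, 60, 3))
img[..., 0] = np.linspace(0, 1, 60)           # red increases left to right
img[..., 2] = np.linspace(0, 1, 40)[:, None]  # blue increases top to bottom

rotated = np.rot90(img, 2)  # 180-degree rotation, same effect as ROTATE_180

fig, (left, right) = plt.subplots(1, 2)
for ax, picture in [(left, img), (right, rotated)]:
    ax.imshow(picture)
    ax.set_xticks([])  # get rid of tickmarks on x axis
    ax.set_yticks([])  # get rid of tickmarks on y axis
fig.tight_layout()

# A 180-degree turn is the same as flipping both rows and columns
print(np.array_equal(rotated, img[::-1, ::-1]))
```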

The script below, borrowed from the matplotlib gallery, shows another common idiom for getting a figure and axes pair: call plt.subplots with no arguments, then talk to ax directly, for the most part. We're also rotating the x tickmark labels by 45 degrees. Fancy!

Uncommenting the use('classic') command up top makes a huge difference in the result. I'm still trying to figure that out.
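As for plt.subplots itself: it also takes a grid shape. With arguments, the second return value is a NumPy array of Axes rather than a single Axes, indexed by [row, column]. A minimal sketch; the 2 x 3 shape and the sharex/sharey settings are illustrative choices:

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()  # no arguments: one Figure, one Axes
print(ax.figure is fig)

fig2, axes = plt.subplots(2, 3, sharex=True, sharey=True)
print(axes.shape)         # one Axes per grid cell

# Label each cell with its (row, column) index, talking to each Axes directly
for (row, col), cell in np.ndenumerate(axes):
    cell.text(0.5, 0.5, str((row, col)), ha="center", va="center")
```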

In [ ]:
vegetables = ["cucumber", "tomato", "lettuce", "asparagus",
              "potato", "wheat", "barley"]
farmers = ["Farmer Joe", "Upland Bros.", "Smith Gardening",
           "Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]

harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
                    [2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
                    [1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
                    [0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
                    [0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
                    [1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],
                    [0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])

fig, ax = plt.subplots()
im = ax.imshow(harvest)

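# We want to show all ticks...
ax.set_xticks(np.arange(len(farmers)))
ax.set_yticks(np.arange(len(vegetables)))
# ... and label them with the respective list entries
ax.set_xticklabels(farmers)
ax.set_yticklabels(vegetables)

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")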

# Loop over data dimensions and create text annotations.
for i in range(len(vegetables)):
    for j in range(len(farmers)):
        text = ax.text(j, i, harvest[i, j],
                       ha="center", va="center", color="y")

ax.set_title("Harvest of local farmers (in tons/year)")
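In newer matplotlib releases (3.5 and later, as far as I know), set_xticks can take the labels and their text properties in one call, folding the set-ticks, set-labels, and setp steps together. A minimal sketch on a hypothetical 2 x 3 array, assuming matplotlib 3.5 or newer:

```python
import numpy as np
import matplotlib.pyplot as plt

rows = ["north", "south"]             # hypothetical row labels
cols = ["spring", "summer", "autumn"] # hypothetical column labels
data = np.array([[1.0, 2.5, 0.7],
                 [3.1, 0.2, 1.8]])

fig, ax = plt.subplots()
ax.imshow(data)

# Ticks at every cell, labelled and rotated in one call each (matplotlib >= 3.5)
ax.set_xticks(np.arange(len(cols)), labels=cols,
              rotation=45, ha="right", rotation_mode="anchor")
ax.set_yticks(np.arange(len(rows)), labels=rows)

# Annotate each cell with its value, as in the gallery script
for i in range(len(rows)):
    for j in range(len(cols)):
        ax.text(j, i, data[i, j], ha="center", va="center", color="w")

fig.tight_layout()
print([t.get_text() for t in ax.get_xticklabels()])
```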