Plotting and Fitting with Python

matplotlib is the main plotting library for Python


In [ ]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
from astropy.table import QTable

Simple Plotting


In [ ]:
t = np.linspace(0,2,100)               # 100 points linearly spaced between 0.0 and 2.0
s = np.cos(2*np.pi*t) * np.exp(-t)     # s is a function of t

In [ ]:
plt.plot(t,s)

Simple plotting - with style

The default matplotlib style is rather plain; some would call it ugly. Newer versions of matplotlib ship with a set of style sheets that you can use in place of the default. Changing the style with plt.style.use() affects all of the following plots in the notebook.

Examples of the various styles can be found here
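Because plt.style.use() changes every later plot, you may prefer to apply a style only temporarily. The sketch below uses plt.style.context(), a context manager in matplotlib.pyplot, so that only the figure created inside the with-block gets the 'ggplot' look. It recreates the t and s arrays from the Simple Plotting cell above; the Agg backend line is only there so the sketch also runs as a plain script.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this also runs as a script; omit it in the notebook
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0, 2, 100)
s = np.cos(2 * np.pi * t) * np.exp(-t)

# 'ggplot' applies only to figures created inside the with-block
with plt.style.context('ggplot'):
    fig1, ax1 = plt.subplots()
    ax1.plot(t, s)

# Created after the block, so it uses whatever style was active before
fig2, ax2 = plt.subplots()
ax2.plot(t, s)

print(ax1.get_facecolor(), ax2.get_facecolor())
```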


In [ ]:
plt.style.available

In [ ]:
plt.style.use('ggplot')

In [ ]:
plt.plot(t,s)

In [ ]:
plt.xlabel('time (s)')
plt.ylabel('voltage (mV)')
plt.title('This is a title')

plt.ylim(-1.5,1.5)

plt.plot(t, s, color='b', marker='None', linestyle='--');   # adding ';' at the end suppresses the Out[] line

In [ ]:
mask1 = np.where((s>-0.4) & (s<0))

plt.plot(t, s, color='b', marker='None', linestyle='--')

plt.plot(t[mask1],s[mask1],color="g",marker="o",linestyle="None",markersize=8);
Colors          Markers          Linestyles
'b'  blue       's'  square      '-'   solid line
'g'  green      'o'  circle      '--'  dashed
'r'  red        '^'  triangle    ':'   dotted
'c'  cyan       '+'  plus        '-.'  dash-dot
'm'  magenta    '.'  dot
'y'  yellow     '*'  star
'k'  black      'D'  diamond
'w'  white      'x'  cross

Sizes such as markersize are given in points (72 pt = 1 in).
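The color, marker, and linestyle codes in the table can also be combined into a single positional format string passed to plot(): 'b--' means a blue dashed line and 'go' means green circles with no connecting line. A minimal sketch, reusing the t and s arrays from above:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt
from matplotlib.colors import same_color
import numpy as np

t = np.linspace(0, 2, 100)
s = np.cos(2 * np.pi * t) * np.exp(-t)

fig, ax = plt.subplots()
# 'b--' is shorthand for color='b', linestyle='--' (no marker)
dashed, = ax.plot(t, s, 'b--')
# 'go' is shorthand for color='g', marker='o'; a marker with no line code draws no line
dots, = ax.plot(t[::10], s[::10], 'go', markersize=8)
```

The keyword form used in the cells above is more explicit; the format string is handy for quick interactive plots.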

In addition, you can specify colors in many different ways:

  • Grayscale intensities: color = '0.8'
  • RGB triplets: color = (0.3, 0.1, 0.9)
  • RGB triplets (with transparency): color = (0.3, 0.1, 0.9, 0.4)
  • Hex strings: color = '#7fff00'
  • HTML color names: color = 'Chartreuse'
  • a name from the xkcd color survey prefixed with 'xkcd:' (e.g., 'xkcd:sky blue')
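Each of these specifications resolves to the same internal form, an (R, G, B, A) tuple of floats between 0 and 1. The functions matplotlib.colors.to_rgba() and to_hex() perform that conversion, which makes them a quick way to check a color spec before plotting:

```python
from matplotlib.colors import to_hex, to_rgba

specs = ['0.8',                  # grayscale intensity
         (0.3, 0.1, 0.9),        # RGB triplet
         (0.3, 0.1, 0.9, 0.4),   # RGB triplet with transparency (alpha = 0.4)
         '#7fff00',              # hex string
         'Chartreuse',           # HTML color name
         'xkcd:sky blue']        # xkcd color survey name

for spec in specs:
    print(f'{spec!r:24} -> {to_rgba(spec)}')

# The hex string and the HTML name above describe the same color
same = to_hex('Chartreuse') == '#7fff00'
```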

matplotlib can plot Astropy Quantity objects directly once quantity_support() is enabled


In [ ]:
from astropy import units as u

from astropy.visualization import quantity_support
quantity_support()

In [ ]:
v = 10 * u.m / u.s
t2 = np.linspace(0,10,1000) * u.s
y = v * t2

In [ ]:
plt.plot(t2,y)

Simple Histograms


In [ ]:
#Histogram of "h" with 20 bins

np.random.seed(42)
h = np.random.randn(500)

plt.hist(h, bins=20, facecolor='MediumOrchid');

In [ ]:
mask2 = np.where(h>0.0)

np.random.seed(42)
j = np.random.normal(2.0,1.0,300)  # normal distribution, mean = 2.0, std = 1.0

plt.hist(h[mask2], bins=20, facecolor='#b20010', histtype='stepfilled')
plt.hist(j,        bins=20, facecolor='#0200b0', histtype='stepfilled', alpha = 0.30);
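The two samples above differ in size (roughly half of the 500 values of h are positive, versus 300 values of j), so their raw counts are not directly comparable. ax.hist() accepts density=True, which rescales each histogram to unit area, and, like plt.hist(), it returns the counts, the bin edges, and the drawn patches. A sketch; it uses NumPy's Generator API, so its random numbers differ from the seeded cells above:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
h = rng.standard_normal(500)
j = rng.normal(2.0, 1.0, 300)      # mean = 2.0, std = 1.0

fig, ax = plt.subplots()
# hist() returns (counts, bin_edges, patches); with density=True the counts become densities
dens_h, edges_h, _ = ax.hist(h[h > 0.0], bins=20, density=True,
                             histtype='stepfilled', facecolor='#b20010')
dens_j, edges_j, _ = ax.hist(j, bins=20, density=True,
                             histtype='stepfilled', facecolor='#0200b0', alpha=0.30)

# Each normalized histogram has unit area, whatever the sample size
area_h = np.sum(dens_h * np.diff(edges_h))
area_j = np.sum(dens_j * np.diff(edges_j))
print(area_h, area_j)
```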

The object-oriented interface, which works with explicit Figure and Axes objects, gives you finer control over the plot.

While most plt functions translate directly to ax methods (such as plt.plot() → ax.plot(), plt.legend() → ax.legend(), etc.), this is not the case for all commands. In particular, the functions that set limits, labels, and titles are renamed. When moving from MATLAB-style functions to object-oriented methods, make the following changes:

  • plt.xlabel() → ax.set_xlabel()
  • plt.ylabel() → ax.set_ylabel()
  • plt.xlim() → ax.set_xlim()
  • plt.ylim() → ax.set_ylim()
  • plt.title() → ax.set_title()
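The pyplot functions act on the current Axes (plt.gca()), so plt.xlabel('...') is equivalent to plt.gca().set_xlabel('...'). Axes also provide a set() method that sets several properties in one call. A short sketch checking both, with the labels borrowed from the earlier cells:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0, 2, 100)
s = np.cos(2 * np.pi * t) * np.exp(-t)

fig, ax = plt.subplots()
ax.plot(t, s)

plt.xlabel('time (s)')               # pyplot style: acts on the current Axes, which is ax
ax.set_ylabel('voltage (mV)')        # object-oriented style
ax.set(xlim=(0.0, 1.5), ylim=(-1.5, 1.5), title='This is a title')   # several setters at once
```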

In [ ]:
fig,ax = plt.subplots(1,1)                    # One window
fig.set_size_inches(11,8.5)                   # (width,height) - letter paper landscape

fig.tight_layout()                          # Make better use of space on plot

ax.set_xlim(0.0,1.5)

ax.spines['bottom'].set_position('zero')    # Move the bottom axis line to y = 0

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

ax.plot(t, s, color='b', marker='None', linestyle='--')

ax.text(0.8, 0.6, 'Bad Wolf', color='green', fontsize=36)            # You can place text on the plot

ax.vlines(0.4, -0.4, 0.8, color='m', linewidth=3)                    # vlines(x, ymin, ymax)
ax.hlines(0.8,  0.2, 0.6, color='y', linewidth=5)                    # hlines(y, xmin, xmax)
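ax.text() positions text in data coordinates by default, so 'Bad Wolf' moves when the axis limits change. Passing transform=ax.transAxes places it in axes coordinates instead, where (0, 0) is the lower-left and (1, 1) the upper-right corner of the axes. Similarly, ax.axhline() and ax.axvline() draw lines across the whole axes, unlike hlines() and vlines(), which need explicit end points. A sketch:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0, 2, 100)
s = np.cos(2 * np.pi * t) * np.exp(-t)

fig, ax = plt.subplots()
ax.plot(t, s, color='b', linestyle='--')

# Upper-left corner of the axes, independent of the data limits
label = ax.text(0.05, 0.95, 'Bad Wolf', transform=ax.transAxes,
                color='green', fontsize=20, va='top')
# Horizontal line at y = 0.8 spanning the full width of the axes
hline = ax.axhline(0.8, color='y', linewidth=3)
```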

In [ ]:
fig.savefig('fig1.png', bbox_inches='tight')
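savefig() infers the output format from the file extension (a name ending in .pdf gives a PDF). The format can also be passed explicitly, and dpi sets the resolution of raster formats such as PNG. As a sketch, the figure can even be written to an in-memory buffer instead of a file:

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

buf = io.BytesIO()
# No file name to infer the format from, so pass format= explicitly
fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
png_bytes = buf.getvalue()
```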

In [ ]:
import glob

glob.glob('*.png')

Plotting from multiple external data files


In [ ]:
data_list = glob.glob('./MyData/12_data*.csv')

data_list

In [ ]:
fig,ax = plt.subplots(1,1)                    # One window
fig.set_size_inches(11,8.5)                   # (width,height) - letter paper landscape

fig.tight_layout()                          # Make better use of space on plot

ax.set_xlim(0.0,80.0)
ax.set_ylim(15.0,100.0)

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

for file in data_list:
    
    data = QTable.read(file, format='ascii.csv')
    ax.plot(data['x'], data['y'],marker="o",linestyle="None",markersize=7,label=file)
    
ax.legend(loc=0,shadow=True);
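The loop above labels each data set with its full path. A variant that labels with os.path.basename() and sorts the file list (glob() makes no promise about order) is sketched below. To keep it self-contained, it first writes two small demo files, hypothetical stand-ins for MyData/12_data*.csv with x and y columns, and reads them with np.genfromtxt() instead of QTable:

```python
import glob
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical demo data: two small CSV files with an "x,y" header line
tmpdir = tempfile.mkdtemp()
for k in range(2):
    x = np.arange(5.0)
    path = os.path.join(tmpdir, f'12_data{k}.csv')
    np.savetxt(path, np.column_stack([x, (k + 1) * x]), delimiter=',', header='x,y', comments='')

fig, ax = plt.subplots()
for file in sorted(glob.glob(os.path.join(tmpdir, '12_data*.csv'))):
    data = np.genfromtxt(file, delimiter=',', names=True)   # structured array with fields 'x' and 'y'
    ax.plot(data['x'], data['y'], marker='o', linestyle='None',
            markersize=7, label=os.path.basename(file))

leg = ax.legend(loc='best', shadow=True)
labels = [txt.get_text() for txt in leg.get_texts()]
print(labels)
```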

Legend loc codes (the string names can also be passed directly, e.g. loc='best'):

0  best            6   center left
1  upper right     7   center right
2  upper left      8   lower center
3  lower left      9   upper center
4  lower right     10  center
5  right

Subplots

  • plt.subplots(rows, columns)
  • Access each subplot like a matrix: ax[row, column]
  • For example, plt.subplots(2,2) makes four panels with the coordinates:

    [0,0]  [0,1]
    [1,0]  [1,1]
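plt.subplots() returns the panels as a NumPy array of Axes: 2-D when there are several rows and columns, but 1-D when there is only one row or one column (which is why a single row needs only one index). Iterating over ax.flat visits the panels in row-major order. A quick check:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2)
print(ax.shape)                      # (2, 2): index as ax[row, column]
for k, panel in enumerate(ax.flat):  # order: [0,0], [0,1], [1,0], [1,1]
    panel.set_title(f'panel {k}')

fig_row, ax_row = plt.subplots(1, 2)
print(ax_row.shape)                  # (2,): a single row is squeezed to one index
```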

In [ ]:
fig, ax = plt.subplots(2,2)                                               # 2 rows 2 columns
fig.set_size_inches(11,8.5)                                               # width, height

fig.tight_layout()                                                        # Make better use of space on plot

ax[0,0].plot(t, s, color='b', marker='None', linestyle='--')              # Plot at [0,0]

ax[0,1].hist(h, bins=20, facecolor='MediumOrchid')                        # Plot at [0,1]

ax[1,0].hist(j,bins=20, facecolor='HotPink', histtype='stepfilled')       # Plot at [1,0]
ax[1,0].vlines(2.0, 0.0, 50.0, color='xkcd:seafoam green', linewidth=3)

ax[1,1].set_xscale('log')                                                 # Plot at [1,1] - x-axis set to log
ax[1,1].plot(t, s, color='r', marker='None', linestyle='--');

An Astronomical Example - Color Magnitude Diagrams


In [ ]:
T = QTable.read('M15_Bright.csv', format='ascii.csv')
T[0:3]

In [ ]:
fig, ax = plt.subplots(1,2)                 # 1 row, 2 columns
fig.set_size_inches(15,5)

fig.tight_layout()

# The plot for [0]
#  Notice that for a single row of plots you do not need to specify the row

ax[0].set_xlim(-40,140)
ax[0].set_ylim(-120,120)

ax[0].set_aspect('equal')                    # Force intervals in x = intervals in y
ax[0].invert_xaxis()                         # RA increases to the left!

ax[0].set_xlabel(r"$\Delta$RA [sec]")        # raw strings keep the LaTeX backslash intact
ax[0].set_ylabel(r"$\Delta$Dec [sec]")

ax[0].plot(T['RA'], T['Dec'],color="g",marker="o",linestyle="None",markersize=5);

# The plot for [1]

BV = T['Bmag'] - T['Vmag']
V = T['Vmag']

ax[1].set_xlim(-0.25,1.5)
ax[1].set_ylim(12,19)

ax[1].set_aspect(1/6)         # Make 1 unit in X = 6 units in Y
ax[1].invert_yaxis()          # Brighter stars have smaller magnitudes, so put them at the top

ax[1].set_xlabel("B-V")
ax[1].set_ylabel("V")

ax[1].plot(BV,V,color="b",marker="o",linestyle="None",markersize=5);

# overplotting

maskC = np.where((V < 16.25) & (BV < 0.55))     

ax[0].plot(T['RA'][maskC], T['Dec'][maskC],color="r",marker="o",linestyle="None",markersize=4, alpha=0.5)
ax[1].plot(BV[maskC], V[maskC],color="r",marker="o",linestyle="None",markersize=4, alpha=0.5);
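np.where(condition) returns a tuple of index arrays; indexing directly with the boolean array selects the same elements and is the more common NumPy idiom. The sketch below uses a few hypothetical magnitudes standing in for the M15 columns:

```python
import numpy as np

# Hypothetical V and B-V values standing in for the M15 table columns
V = np.array([15.1, 16.8, 14.9, 17.5, 16.0])
BV = np.array([0.40, 0.70, 0.60, 0.30, 0.50])

cond = (V < 16.25) & (BV < 0.55)     # elementwise AND; parentheses needed because & binds tighter than <
idx = np.where(cond)                 # tuple of index arrays

same = np.array_equal(V[idx], V[cond])
print(idx, V[cond], BV[cond])
```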

Curve Fitting


In [ ]:
D1 = QTable.read('data1.csv', format='ascii.csv')
D1[0:2]

In [ ]:
plt.plot(D1['x'],D1['y'],marker="o",linestyle="None",markersize=5);

In [ ]:
# Degree-1 (linear) fit: y = ax + b

Fit1 = np.polyfit(D1['x'],D1['y'],1)

Fit1       # The coefficients of the fit (a,b)

In [ ]:
Yfit = np.polyval(Fit1,D1['x'])   # The polynomial of Fit1 applied to the points D1['x']

plt.plot(D1['x'], D1['y'], marker="o", linestyle="None", markersize=5)

plt.plot(D1['x'], Yfit, linewidth=4, color='c', linestyle='--')
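np.polyfit() returns the coefficients highest power first, which is the order np.polyval() expects. A sketch on synthetic straight-line data (slope 3, intercept 2, noise made up for illustration) shows the round trip and one simple goodness-of-fit check, the RMS of the residuals:

```python
import numpy as np

rng = np.random.default_rng(1)
x = np.linspace(0, 10, 50)
y = 3.0 * x + 2.0 + rng.normal(0.0, 0.2, x.size)   # synthetic data: slope 3, intercept 2, noise 0.2

fit = np.polyfit(x, y, 1)             # [a, b] for y = a*x + b
yfit = np.polyval(fit, x)             # evaluate the fitted line at the data points
rms = np.sqrt(np.mean((y - yfit) ** 2))
print(fit, rms)
```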

In [ ]:
D2 = QTable.read('data2.csv', format='ascii.csv')

In [ ]:
plt.plot(D2['x'],D2['y'],marker="o",linestyle="None",markersize=5);

In [ ]:
# Degree-2 (quadratic) fit: y = ax**2 + bx + c

Fit2 = np.polyfit(D2['x'],D2['y'],2)

Fit2

In [ ]:
Yfit = np.polyval(Fit2,D2['x'])

plt.plot(D2['x'], D2['y'], marker="o", linestyle="None", markersize=5)

plt.plot(D2['x'], Yfit, linewidth=3, color='y', linestyle='--');

In [ ]:
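# Be careful: very high-order fits follow the noise and can oscillate wildly between the points
# (np.polyfit may also warn that the fit is poorly conditioned)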

Fit3 = np.polyfit(D1['x'],D1['y'],20)

xx = np.linspace(0,10,200)

Yfit = np.polyval(Fit3,xx)

plt.plot(D1['x'], D1['y'], marker="o", linestyle="None", markersize=8)
plt.plot(xx, Yfit, linewidth=3, color='m', linestyle='--')

plt.ylim(-20,120);
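NumPy's documentation recommends the Polynomial.fit class method for new code because it is numerically more stable: it maps the x values onto the window [-1, 1] before fitting, which helps at higher degrees. Its coefficients come lowest power first, the reverse of np.polyfit(). A sketch with an exact, made-up quadratic:

```python
import numpy as np
from numpy.polynomial import Polynomial

x = np.linspace(0, 10, 30)
y = 2.0 * x**2 - 3.0 * x + 1.0        # made-up quadratic: a = 2, b = -3, c = 1

p = Polynomial.fit(x, y, 2)           # fitted in a scaled domain for stability
coef = p.convert().coef               # unscaled coefficients, lowest power first: [c, b, a]
old = np.polyfit(x, y, 2)             # same fit, highest power first: [a, b, c]
print(coef, old)
```

The fitted object p is callable, so p(xx) plays the role of np.polyval(Fit, xx).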

Fitting a specific function


In [ ]:
D3 = QTable.read('data3.csv', format='ascii.csv')

plt.plot(D3['x'],D3['y'],marker="o",linestyle="None",markersize=5);

In [ ]:
from scipy.optimize import curve_fit

$$ f(x) = a \sin(bx) $$

In [ ]:
def ringo(x,a,b):
    return a*np.sin(b*x)

In [ ]:
Aguess = 75
Bguess = 1.0/5.0

fitpars, covar = curve_fit(ringo,D3['x'],D3['y'],p0=[Aguess,Bguess])

# Function to fit = ringo
# X points to fit = D3['x']
# Y points to fit = D3['y']
# Initial guess at values for a,b = [Aguess,Bguess]
# fitpars = best-fit values of a,b; covar = their estimated covariance matrix

print(fitpars)
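The second value returned by curve_fit() is the estimated covariance matrix of the fitted parameters; the square roots of its diagonal are the usual 1-sigma uncertainties. A sketch on synthetic data generated from f(x) = a sin(bx) with a = 75 and b = 0.2 (values chosen for illustration), starting from the notebook's initial guesses:

```python
import numpy as np
from scipy.optimize import curve_fit

def ringo(x, a, b):
    return a * np.sin(b * x)

rng = np.random.default_rng(3)
x = np.linspace(0, 100, 200)
y = ringo(x, 75.0, 0.2) + rng.normal(0.0, 2.0, x.size)   # synthetic data with noise sigma = 2

fitpars, covar = curve_fit(ringo, x, y, p0=[75.0, 1.0 / 5.0])
perr = np.sqrt(np.diag(covar))        # 1-sigma uncertainties on a and b
for name, val, err in zip('ab', fitpars, perr):
    print(f'{name} = {val:.4f} +/- {err:.4f}')
```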

In [ ]:
Z = np.linspace(0,100,1000)

plt.plot(Z, ringo(Z, *fitpars), 'r-')
plt.plot(Z, ringo(Z,Aguess,Bguess), 'g--')

plt.plot(D3['x'],D3['y'],marker="o",linestyle="None",markersize=5);

Bad initial guesses can lead to very bad fits


In [ ]:
Aguess = 35
Bguess = 1.0

fitpars, covar = curve_fit(ringo,D3['x'],D3['y'],p0=[Aguess,Bguess])

print(fitpars)

plt.plot(Z, ringo(Z, *fitpars), 'r-')
plt.plot(Z, ringo(Z,Aguess,Bguess), 'g--')

plt.plot(D3['x'],D3['y'],marker="o",linestyle="None",markersize=5);
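One way to guard against bad starting values is to try several initial guesses and keep the result with the smallest sum of squared residuals. This multi-start loop is a general strategy, not something curve_fit() does on its own; the synthetic data and the grid of starting values for b below are assumptions for illustration:

```python
import numpy as np
from scipy.optimize import curve_fit

def ringo(x, a, b):
    return a * np.sin(b * x)

rng = np.random.default_rng(3)
x = np.linspace(0, 100, 200)
y = ringo(x, 75.0, 0.2) + rng.normal(0.0, 2.0, x.size)   # synthetic data, true a = 75, b = 0.2

best_pars, best_ssr = None, np.inf
for b0 in np.linspace(0.05, 1.0, 20):          # grid of starting values for b
    try:
        pars, _ = curve_fit(ringo, x, y, p0=[35.0, b0])
    except RuntimeError:                        # raised when the fit fails to converge
        continue
    ssr = np.sum((y - ringo(x, *pars)) ** 2)    # sum of squared residuals
    if ssr < best_ssr:
        best_pars, best_ssr = pars, ssr

print(best_pars, best_ssr)
```

Since a sin(bx) = (-a) sin(-bx), a fit may land on the sign-flipped pair; both describe the same curve.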

Polar Plots


In [ ]:
theta = np.linspace(0,2*np.pi,1000)


fig = plt.figure()
ax = fig.add_subplot(111,projection='polar')

fig.set_size_inches(6,6)                   # (width,height) in inches

fig.tight_layout()                          # Make better use of space on plot

ax.plot(theta,theta/5.0,label="spiral")
ax.plot(theta,np.cos(4*theta),label="flower")

ax.legend(loc=2, frameon=False);
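By default a polar plot puts θ = 0 on the right and increases counterclockwise. PolarAxes provides set_theta_zero_location() and set_theta_direction() to change that, for example to a compass-style layout with 0 at the top and angles increasing clockwise:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

theta = np.linspace(0, 2 * np.pi, 1000)

fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(projection='polar')
ax.set_theta_zero_location('N')      # theta = 0 at the top
ax.set_theta_direction(-1)           # angles increase clockwise
ax.plot(theta, theta / 5.0, label='spiral')
ax.legend(loc='upper left', frameon=False)
```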

Everyone likes Pie


In [ ]:
fig,ax = plt.subplots(1,1)                    # One window
fig.set_size_inches(6,6)                   # (width,height) in inches

fig.tight_layout()                          # Make better use of space on plot

ax.set_aspect('equal')

labels = np.array(['John', 'Paul' ,'George' ,'Ringo'])     # Name of slices
sizes = np.array([0.3, 0.15, 0.45, 0.10])                  # Relative size of slices
colors = np.array(['r', 'g', 'b', 'c'])                    # Color of Slices
explode = np.array([0, 0, 0.1, 0])                         # Pull slice 3 out by 0.1 of the radius

ax.pie(sizes, explode=explode, labels=labels, colors=colors,
       startangle=90, shadow=True);
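ax.pie() can also label each slice with its percentage through autopct, a format string (or function). When autopct is given, pie() returns three lists: the slice patches, the label texts, and the percentage texts. A sketch reusing the slice data from the cell above:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt

labels = ['John', 'Paul', 'George', 'Ringo']
sizes = [0.3, 0.15, 0.45, 0.10]

fig, ax = plt.subplots(figsize=(6, 6))
# '%1.1f%%' formats each slice's share as a percentage with one decimal place
wedges, texts, autotexts = ax.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90)
```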

3D plots


In [ ]:
from mpl_toolkits.mplot3d import Axes3D     # Registers the '3d' projection; not needed in matplotlib >= 3.2

fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')

fig.set_size_inches(9,9)

fig.tight_layout()

xx = np.cos(3*theta)
yy = np.sin(2*theta)

ax.plot(theta, xx, yy, c = "Maroon")
ax.scatter(theta, xx, yy, c = "Navy", s = 15);

ax.view_init(azim = -140, elev = 15)
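3D axes add a z axis with its own set_zlabel() and set_zlim(), and view_init() stores the viewing angles in the elev and azim attributes. A short sketch labeling the curve from the cell above:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np

theta = np.linspace(0, 2 * np.pi, 1000)

fig = plt.figure(figsize=(9, 9))
ax = fig.add_subplot(projection='3d')   # built in since matplotlib 3.2; older versions need the mplot3d import
ax.plot(theta, np.cos(3 * theta), np.sin(2 * theta), color='Maroon')
ax.set_xlabel('theta')
ax.set_ylabel('cos(3 theta)')
ax.set_zlabel('sin(2 theta)')
ax.view_init(elev=15, azim=-140)
```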

Tons of examples of matplotlib plots can be found here