Fitting Data


In [ ]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
from astropy.table import QTable

Curve Fitting - Polynomial


In [ ]:
# Read the first data set into an astropy QTable, then preview its first two rows.
D1 = QTable.read(
    './MyData/fit_data1.csv',
    format='ascii.csv',
)
D1[:2]

In [ ]:
# Scatter plot of the raw D1 data (markers only, no connecting line).
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

ax.plot(D1['x'], D1['y'], marker="o", linestyle="None", markersize=5)

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

First-order fitting: $f(x) = ax + b$


In [ ]:
# Least-squares fit of a first-order polynomial f(x) = a*x + b to the D1 data.
Fit1 = np.polyfit(D1['x'], D1['y'], deg=1)

Fit1       # displays the fitted coefficients (a, b)

In [ ]:
# Evaluate the first-order fit at every x value of the data set.
Yfit = np.polyval(
    Fit1,       # coefficients returned by np.polyfit
    D1['x'],    # points at which the polynomial is evaluated
)

In [ ]:
# Overlay the first-order fit (dashed cyan line) on the D1 data points.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

ax.plot(D1['x'], D1['y'], marker="o", linestyle="None", markersize=5)
ax.plot(D1['x'], Yfit, linewidth=4, color='c', linestyle='--')

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

You can use np.poly1d to explore the fitted polynomial


In [ ]:
# Wrap the first-order fit coefficients in a poly1d object so the fitted
# polynomial can be called, shifted and solved in the cells below.
poly01 = np.poly1d(Fit1)

In [ ]:
# Calling the poly1d object evaluates the fitted polynomial at the given x.
poly01(5)                 # value of f(x) at x = 5

In [ ]:
# The roots of the fitted polynomial, i.e. where the fitted line crosses y = 0.
poly01.roots              # value of x at f(x) = 0

In [ ]:
# Subtracting 20 shifts the polynomial down, so the roots of the shifted
# polynomial are the x values where the original fit equals 20.
(poly01 - 20).roots       # value of x at f(x) = 20

Second-order fitting: $f(x) = ax^2 + bx + c$


In [ ]:
# Read the second data set into an astropy QTable.
D2 = QTable.read(
    './MyData/fit_data2.csv',
    format='ascii.csv',
)

In [ ]:
# Scatter plot of the raw D2 data (markers only, no connecting line).
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

ax.plot(D2['x'], D2['y'], marker="o", linestyle="None", markersize=5)

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

In [ ]:
# Least-squares fit of a second-order polynomial f(x) = a*x^2 + b*x + c to the D2 data.
Fit2 = np.polyfit(D2['x'], D2['y'], deg=2)

Fit2       # displays the fitted coefficients (a, b, c)

In [ ]:
# Evaluate the second-order fit at every x value of the data set.
Yfit = np.polyval(
    Fit2,       # coefficients returned by np.polyfit
    D2['x'],    # points at which the polynomial is evaluated
)

In [ ]:
# Overlay the second-order fit (dashed yellow line) on the D2 data points.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

ax.plot(D2['x'], D2['y'], marker="o", linestyle="None", markersize=5)
ax.plot(D2['x'], Yfit, linewidth=3, color='y', linestyle='--')

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

Explore the fitted polynomial


In [ ]:
# Wrap the second-order fit coefficients in a poly1d object so the fitted
# polynomial can be called, shifted and solved in the cells below.
poly02 = np.poly1d(Fit2)

In [ ]:
# Calling the poly1d object evaluates the fitted polynomial at the given x.
poly02(5)                 # value of f(x) at x = 5

In [ ]:
# The roots of the fitted quadratic, i.e. the x values where the fit equals 0.
poly02.roots              # value of x at f(x) = 0

In [ ]:
# Subtracting 20 shifts the polynomial down, so the roots of the shifted
# polynomial are the x values where the original fit equals 20.
(poly02 - 20).roots       # value of x at f(x) = 20

In [ ]:
# Same shift trick for f(x) = 80. When the fitted curve never reaches that
# value, the roots come back as complex numbers rather than real x values.
(poly02 - 80).roots       # value of x at f(x) = 80, no real root

Be careful! Very high-order polynomial fits can overfit: the curve follows the data points but may swing wildly between them


In [ ]:
# Deliberately excessive: fit a 10th-order polynomial to the D1 data.
Fit3 = np.polyfit(D1['x'], D1['y'], deg=10)

# Evaluate the fit on 200 evenly spaced points from 0 to 10 so that the
# behaviour of the curve between the data points is visible.
xx = np.linspace(0, 10, num=200)
Yfit = np.polyval(Fit3, xx)

In [ ]:
# Overlay the 10th-order fit (dashed magenta line) on the D1 data points.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

ax.plot(D1['x'], D1['y'], marker="o", linestyle="None", markersize=8)
ax.plot(xx, Yfit, linewidth=3, color='m', linestyle='--')

# Fixed y-range for the view, regardless of how far the fitted curve swings.
ax.set_ylim(-20, 120)

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

Fitting a specific function


In [ ]:
# Read the third data set into an astropy QTable.
D3 = QTable.read(
    './MyData/fit_data3.csv',
    format='ascii.csv',
)

In [ ]:
# Scatter plot of the raw D3 data (markers only, no connecting line).
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

ax.plot(D3['x'], D3['y'], marker="o", linestyle="None", markersize=5)

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

In [ ]:
from scipy.optimize import curve_fit
$$ \Large f(x) = a \sin(bx) $$

In [ ]:
def ringo(x,a,b):
    """Sinusoidal model f(x) = a * sin(b * x).

    Parameters
    ----------
    x : scalar or array accepted by numpy
        Point(s) at which the model is evaluated.
    a : number
        Multiplier applied to the sine.
    b : number
        Factor applied to ``x`` inside the sine.

    Returns
    -------
    The value of ``a * np.sin(b * x)``; array inputs are evaluated element-wise.
    """
    phase = b*x
    return a*np.sin(phase)

In [ ]:
# Starting values for the parameters a and b of ringo.
Aguess, Bguess = 75, 1.0/5.0

# curve_fit arguments:
#   f     - function to fit (ringo)
#   xdata - X points to fit (D3['x'])
#   ydata - Y points to fit (D3['y'])
#   p0    - initial guess at values for a,b
# curve_fit returns the best-fit parameters and their estimated covariance
# matrix; the covariance matrix is what ends up in `error`.
fitpars, error = curve_fit(
    f=ringo,
    xdata=D3['x'],
    ydata=D3['y'],
    p0=[Aguess, Bguess],
)

print(fitpars)

In [ ]:
# Dense x grid (1000 points from 0 to 100) for drawing smooth model curves.
Z = np.linspace(0, 100, 1000)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# Fitted model, model at the initial guess, then the data drawn on top.
ax.plot(Z, ringo(Z, *fitpars), 'r-', label="fit")
ax.plot(Z, ringo(Z, Aguess, Bguess), 'g--', label="initial guess")
ax.plot(D3['x'], D3['y'], marker="o", linestyle="None", markersize=5, label="data")

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# The legend identifies which curve is the fit and which is the starting guess.
ax.legend()

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

Bad initial guesses can lead to very bad fits


In [ ]:
# A different starting point for (a, b), used to show how the initial guess
# affects the result of the fit.
Aguess, Bguess = 35, 1.0

fitpars, error = curve_fit(
    f=ringo,
    xdata=D3['x'],
    ydata=D3['y'],
    p0=[Aguess, Bguess],
)

print(fitpars)

In [ ]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# Fitted model, model at the initial guess, then the data drawn on top.
# Z is the dense x grid defined in the earlier plotting cell.
ax.plot(Z, ringo(Z, *fitpars), 'r-', label="fit")
ax.plot(Z, ringo(Z, Aguess, Bguess), 'g--', label="initial guess")
ax.plot(D3['x'], D3['y'], marker="o", linestyle="None", markersize=5, label="data")

ax.set_xlabel("This is X")
ax.set_ylabel("This is Y")

# The legend identifies which curve is the fit and which is the starting guess.
ax.legend()

# tight_layout() adjusts the margins once, at call time, so it has to run
# after the labels exist for them to be taken into account.
fig.tight_layout()

Pandas - Plotting and data exploration


In [ ]:
import pandas as pd

In [ ]:
# Load the Anscombe data set into a pandas DataFrame.
strange_data = pd.read_csv(
    './MyData/anscombe.csv',
)

In [ ]:
# Summary statistics for every numeric column of the DataFrame.
strange_data.describe()

In [ ]:
# Scatter plot of the first x/y pair (red markers, size 50).
strange_data.plot(kind='scatter', x='x1', y='y1', color='r', s=50);

In [ ]:
# One separate scatter figure per x/y pair: (x1, y1) red, (x2, y2) blue,
# (x3, y3) green, (x4, y4) black.
for set_no, colour in enumerate('rbgk', start=1):
    strange_data.plot.scatter(x=f'x{set_no}', y=f'y{set_no}', color=colour, s=50)

In [ ]:
# Draw all four x/y pairs on one set of axes: the first call creates the axes
# (ax=None), and every later call is handed those same axes via ax=ax.
ax = None
for set_no, colour in enumerate('rbgk', start=1):
    ax = strange_data.plot.scatter(
        x=f'x{set_no}',
        y=f'y{set_no}',
        color=colour,
        s=50,
        label=f"set {set_no}",
        ax=ax,
    )

Pandas is very good at time series data


In [ ]:
# Load the stock prices. parse_dates converts the 'Date' column to datetimes
# on load, so that plotting against it uses a real time axis rather than
# plain date strings.
stock_data = pd.read_csv('./MyData/goog.csv', parse_dates=['Date'])

In [ ]:
# Preview the first row to see which columns are available.
stock_data.iloc[:1]

In [ ]:
# Line plot (red) of the 'Close' column against the 'Date' column.
stock_data.plot.line(x='Date', y='Close', color='r');

In [ ]: