In [6]:
import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
%matplotlib inline
from time import process_time
In [7]:
def example(m, n):
    """Build a small least-squares test problem with a known solution.

    Returns the m x n Vandermonde matrix (decreasing powers, as produced by
    ``np.vander``) evaluated at m equally spaced nodes in [-1, 1], together
    with the right-hand side b = A @ ones(n).  The system A x = b is therefore
    consistent, and x = ones(n) solves it.
    """
    nodes = np.linspace(-1., 1., m)
    vander = np.vander(nodes, n)
    rhs = np.dot(vander, np.ones(n))
    return vander, rhs
In [8]:
def _solve_qr(A, b):
    """Least-squares solve via the economic QR factorization A = Q R.

    Requires m >= n so that R is square; scipy's solve_triangular raises
    LinAlgError if R is singular (rank-deficient A).
    """
    Q, R = linalg.qr(A, mode='economic')
    # In economic mode Q has k = min(m, n) columns, so Q^T b already has length k.
    y = Q.T.dot(b)
    return linalg.solve_triangular(R, y)


def _solve_normal(A, b):
    """Least-squares solve via the normal equations A^T A x = A^T b."""
    ATA = A.T.dot(A)
    ATb = A.T.dot(b)
    try:
        L = linalg.cholesky(ATA, lower=True)
    except linalg.LinAlgError:
        # Cholesky needs A^T A to be numerically positive definite, which fails
        # for rank-deficient or badly conditioned A (cond(A^T A) = cond(A)**2).
        # Rescaling A cannot fix that, so fall back to a minimum-norm solve of
        # the same normal equations.
        return linalg.lstsq(ATA, ATb)[0]
    y = linalg.solve_triangular(L, ATb, lower=True)
    # L^T is upper triangular, which is solve_triangular's default (lower=False).
    return linalg.solve_triangular(L.T, y)


def _solve_svd(A, b):
    """Minimum-norm least-squares solve via the thin SVD A = U diag(s) V^T."""
    U, s, VT = linalg.svd(A, full_matrices=False)
    # Singular values below eps * max(m, n) * s_max are treated as zero (the
    # default cutoff of numpy.linalg.lstsq); inverting them would produce
    # inf/nan or amplify rounding noise without bound.
    cutoff = np.finfo(s.dtype).eps * max(A.shape) * s.max(initial=0.0)
    inv_s = np.zeros_like(s)
    keep = s > cutoff
    inv_s[keep] = 1.0 / s[keep]
    # x = V diag(inv_s) U^T b; scaling the columns of V avoids forming diag().
    return (VT.T * inv_s).dot(U.T.dot(b))


_LSQ_SOLVERS = {'qr': _solve_qr, 'normal': _solve_normal, 'svd': _solve_svd}


def solve_lsq(A, b, method='qr'):
    """Solve the linear least-squares problem min_x ||A x - b||_2.

    Parameters
    ----------
    A : 2-D array, shape (m, n)
        System matrix. The 'qr' method requires m >= n.
    b : array, length m
        Right-hand side.
    method : {'qr', 'normal', 'svd'}
        'qr'     -- economic QR factorization followed by a triangular solve.
        'normal' -- Cholesky on the normal equations A^T A x = A^T b; falls
                    back to a minimum-norm solve of those equations when
                    A^T A is not numerically positive definite.
        'svd'    -- thin SVD with small singular values truncated; returns
                    the minimum-norm solution.

    Returns
    -------
    x : array, length n

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above.
    scipy.linalg.LinAlgError
        From the 'qr' path when R is singular.
    """
    try:
        solver = _LSQ_SOLVERS[method]
    except KeyError:
        raise ValueError(
            "unknown method {!r}; expected one of {}".format(method, sorted(_LSQ_SOLVERS))
        ) from None
    return solver(A, b)
In [13]:
# Time solve_lsq on the m x n Vandermonde problem for every column count n = 1, ..., m-1.
# process_time() measures CPU time of this process, not wall-clock time.
m = 100
# times[n] holds the time for n columns. n = 0 is never measured, so it stays NaN
# (skipped by matplotlib) instead of appearing as a fake 0-second data point.
times = np.full(m, np.nan)
for n in range(1, m):
    A, b = example(m, n)
    t_start = process_time()
    x = solve_lsq(A, b, method='svd')
    times[n] = process_time() - t_start
In [14]:
# Plot CPU time against the number of columns n (times is indexed by n).
fig, ax = plt.subplots()
ax.plot(times, marker='.')
ax.set(xlabel='number of columns n', ylabel='CPU time [s]',
       title='solve_lsq CPU time, m = {}'.format(m))
plt.show()
Out[14]:
In [ ]:
In [ ]: